<a href="https://colab.research.google.com/github/Suryanshagrawal19/Proof_Generation/blob/main/transformer%20(greedy_approach).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
from torch.nn import Transformer
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer
import pandas as pd  # To load the Parquet dataset
from tqdm import tqdm  # For progress bar
from nltk.translate.bleu_score import sentence_bleu
from rouge_score import rouge_scorer
from nltk.translate.meteor_score import meteor_score

# Helper function: Compute token-level exact match accuracy.
def get_match_accuracy(generated: str, actual: str) -> float:
    """
    Computes the percentage of tokens in the actual proof that match the tokens in the generated proof (in order).
    Returns a percentage value.
    """
    gen_tokens = generated.split()
    actual_tokens = actual.split()
    if not actual_tokens:
        return 0.0
    # Count tokens that match in the same position.
    matches = sum(1 for gt, at in zip(gen_tokens, actual_tokens) if gt == at)
    return (matches / len(actual_tokens)) * 100

# 1. Load Parquet Dataset Function
def load_dataset(file_path="combined_dataset.parquet"):
    try:
        df = pd.read_parquet(file_path)
        proof_pairs = list(zip(df["Output"], df["Theorem"]))  # Adjust column names as needed.
        print(f"Loaded {len(proof_pairs)} examples from {file_path}.")
        return proof_pairs
    except Exception as e:
        print("Error loading dataset:", e)
        # Return a small dummy dataset if the file is not found.
        return [("axiom A -> B", "theorem A implies B"), ("forall x, P(x)", "universal quantifier P(x)")]

# 2. Model & Configuration Classes
class ProofTransformerConfig(PretrainedConfig):
    def __init__(self, vocab_size=30522, d_model=512, num_layers=8, nhead=8,
                 dim_feedforward=4096, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.num_layers = num_layers
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout

class ProofTransformer(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        # Using a fixed maximum sequence length of 1000 for positional encoding.
        self.positional_encoding = nn.Parameter(torch.zeros(1, 1000, config.d_model))

        self.encoder = Transformer(
            d_model=config.d_model,
            nhead=config.nhead,
            num_encoder_layers=config.num_layers,
            num_decoder_layers=config.num_layers,
            dim_feedforward=config.dim_feedforward,
            dropout=config.dropout,
        )

        self.lm_head = nn.Linear(config.d_model, config.vocab_size)
        self.init_weights()

    def forward(self, input_ids, decoder_input_ids):
        src = self.token_embedding(input_ids) + self.positional_encoding[:, :input_ids.shape[1], :]
        tgt = self.token_embedding(decoder_input_ids) + self.positional_encoding[:, :decoder_input_ids.shape[1], :]

        # Transformer expects input as [sequence_length, batch_size, d_model]
        src = src.transpose(0, 1)
        tgt = tgt.transpose(0, 1)

        output = self.encoder(src, tgt)

        output = output.transpose(0, 1)
        logits = self.lm_head(output)
        return {"logits": logits}

# 3. Dataset, DataLoader, Tokenizer
class ProofDataset(Dataset):
    def __init__(self, tokenizer, proof_pairs, max_length=512):
        self.tokenizer = tokenizer
        self.data = proof_pairs
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        input_text, target_text = self.data[idx]
        input_ids = self.tokenizer.encode(
            input_text, add_special_tokens=True,
            return_tensors="pt", max_length=self.max_length, truncation=True
        ).squeeze(0)
        target_ids = self.tokenizer.encode(
            target_text, add_special_tokens=True,
            return_tensors="pt", max_length=self.max_length, truncation=True
        ).squeeze(0)
        # Return as tuple of dictionaries (each representing one side of the pair)
        return {"input_ids": input_ids}, {"input_ids": target_ids}

def collate_fn(batch):
    # Since __getitem__ returns a tuple (input_dict, target_dict), we unpack here:
    inputs, targets = zip(*batch)
    input_ids = [item["input_ids"] for item in inputs]
    target_ids = [item["input_ids"] for item in targets]
    input_ids = nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
    target_ids = nn.utils.rnn.pad_sequence(target_ids, batch_first=True, padding_value=0)
    return {"input_ids": input_ids, "target_ids": target_ids}

# 4. Greedy Decoding Function
def greedy_decode(logits):
    # Select the token with the highest probability at each time step.
    token_ids = torch.argmax(logits, dim=-1)
    # Return the tokens for the first sequence in the batch.
    return token_ids[0].cpu().tolist()

# 5. Evaluation Function with Additional Metrics (ROUGE, METEOR, and Match Accuracy)
def evaluate_accuracy(model, dataloader, tokenizer, device):
    model.eval()
    correct = 0
    total = 0
    bleu_scores = []
    meteor_scores = []
    rouge_scorer_obj = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    rouge_scorer_values = []
    match_accuracies = []  # List to store match accuracies for each proof

    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(dataloader), desc="Evaluating"):
            input_ids = batch["input_ids"].to(device)
            target_ids = batch["target_ids"].to(device)

            outputs = model(input_ids=input_ids, decoder_input_ids=target_ids)
            logits = outputs["logits"]

            # Use greedy decoding for evaluation on the first example of the batch.
            decoded_tokens = greedy_decode(logits)
            generated_proof = tokenizer.decode(decoded_tokens, skip_special_tokens=True)
            actual_proof = tokenizer.decode(target_ids[0], skip_special_tokens=True)

            # Print proofs and calculate match accuracy.
            print(f"Generated Proof: {generated_proof}")
            print(f"Actual Proof:    {actual_proof}")
            match_acc = get_match_accuracy(generated_proof, actual_proof)
            match_accuracies.append(match_acc)
            print(f"Match Accuracy:  {match_acc:.2f}%")
            print("-" * 80)

            # Exact match evaluation (full string exact match).
            if generated_proof == actual_proof:
                correct += 1

            # Compute BLEU score for semantic similarity.
            bleu_score = sentence_bleu([actual_proof.split()], generated_proof.split())
            bleu_scores.append(bleu_score)

            # Compute METEOR score (using tokenized strings).
            meteor_scores.append(meteor_score([actual_proof.split()], generated_proof.split()))

            # Compute ROUGE score (sum of f-measures from rouge1 and rougeL).
            rouge_score_dict = rouge_scorer_obj.score(actual_proof, generated_proof)
            rouge_total = rouge_score_dict['rouge1'].fmeasure + rouge_score_dict['rougeL'].fmeasure
            rouge_scorer_values.append(rouge_total)

            total += 1

    if total == 0:
        print("No examples evaluated in the test set. Check your test data slice.")
        return

    accuracy = (correct / total * 100) if total > 0 else 0
    avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0
    avg_meteor = sum(meteor_scores) / len(meteor_scores) if meteor_scores else 0
    avg_rouge = sum(rouge_scorer_values) / len(rouge_scorer_values) if rouge_scorer_values else 0
    avg_match_acc = sum(match_accuracies) / len(match_accuracies) if match_accuracies else 0

    print(f"\nExact Match Accuracy: {accuracy:.2f}%")
    print(f"Average BLEU Score:   {avg_bleu:.4f}")
    print(f"Average METEOR Score: {avg_meteor:.4f}")
    print(f"Average ROUGE Score:  {avg_rouge:.4f}")
    print(f"Average Match Accuracy: {avg_match_acc:.2f}%")

# 6. Initialize Tokenizer, Model, Optimizer, and set Device
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

config = ProofTransformerConfig()
model = ProofTransformer(config)
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
epochs = 28
reward_weight = 0.1

# Load dataset from Parquet file
proof_data = load_dataset("combined_dataset.parquet")

# 7. Split dataset into training and testing examples.
# Correcting the slice for test data: using indices 2000 to 2400 for a 400-example test set.
train_data = proof_data[:2000]
test_data = proof_data[2000:2400]

train_dataset = ProofDataset(tokenizer, train_data, max_length=512)
test_dataset = ProofDataset(tokenizer, test_data, max_length=512)

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

# 8. Training Loop with Gradient Accumulation
accumulation_steps = 4  # Accumulate gradients over 4 steps (simulates a larger effective batch size)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(enumerate(train_dataloader), desc=f"Epoch {epoch+1}/{epochs}", total=len(train_dataloader))

    # Zero gradients outside the loop to begin accumulation.
    optimizer.zero_grad()

    for batch_idx, batch in progress_bar:
        input_ids = batch["input_ids"].to(device)
        target_ids = batch["target_ids"].to(device)

        # Forward pass.
        outputs = model(input_ids=input_ids, decoder_input_ids=target_ids)
        logits = outputs["logits"]

        # Compute cross-entropy loss. Flatten the tensors so that we can compare predictions with targets.
        ce_loss = loss_fn(logits.view(-1, config.vocab_size), target_ids.view(-1))
        ce_loss.backward()

        # Gradient accumulation step.
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        running_loss += ce_loss.item()
        progress_bar.set_postfix({"loss": ce_loss.item()})

    # If there are leftover gradients, update once more.
    if (batch_idx + 1) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = running_loss / len(train_dataloader)
    print(f"Epoch {epoch+1}/{epochs} completed. Avg Loss: {avg_loss:.4f}")

# 9. Evaluate the model on the test set
import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')
evaluate_accuracy(model, test_dataloader, tokenizer, device)


Using device: cuda




Loaded 100000 examples from combined_dataset.parquet.


Epoch 1/28: 100%|██████████| 250/250 [01:59<00:00,  2.10it/s, loss=3.4]


Epoch 1/28 completed. Avg Loss: 4.9796


Epoch 2/28: 100%|██████████| 250/250 [01:57<00:00,  2.13it/s, loss=1.13]


Epoch 2/28 completed. Avg Loss: 2.0221


Epoch 3/28: 100%|██████████| 250/250 [01:59<00:00,  2.10it/s, loss=0.791]


Epoch 3/28 completed. Avg Loss: 1.0031


Epoch 4/28: 100%|██████████| 250/250 [01:58<00:00,  2.11it/s, loss=0.748]


Epoch 4/28 completed. Avg Loss: 0.6626


Epoch 5/28: 100%|██████████| 250/250 [01:57<00:00,  2.13it/s, loss=0.527]


Epoch 5/28 completed. Avg Loss: 0.4761


Epoch 6/28: 100%|██████████| 250/250 [01:57<00:00,  2.12it/s, loss=0.202]


Epoch 6/28 completed. Avg Loss: 0.3593


Epoch 7/28: 100%|██████████| 250/250 [01:57<00:00,  2.13it/s, loss=0.245]


Epoch 7/28 completed. Avg Loss: 0.2875


Epoch 8/28: 100%|██████████| 250/250 [01:56<00:00,  2.15it/s, loss=0.315]


Epoch 8/28 completed. Avg Loss: 0.2290


Epoch 9/28: 100%|██████████| 250/250 [01:59<00:00,  2.09it/s, loss=0.148]


Epoch 9/28 completed. Avg Loss: 0.1857


Epoch 10/28: 100%|██████████| 250/250 [01:57<00:00,  2.13it/s, loss=0.254]


Epoch 10/28 completed. Avg Loss: 0.1607


Epoch 11/28: 100%|██████████| 250/250 [01:58<00:00,  2.11it/s, loss=0.135]


Epoch 11/28 completed. Avg Loss: 0.1336


Epoch 12/28: 100%|██████████| 250/250 [01:56<00:00,  2.14it/s, loss=0.0758]


Epoch 12/28 completed. Avg Loss: 0.1121


Epoch 13/28: 100%|██████████| 250/250 [01:55<00:00,  2.16it/s, loss=0.178]


Epoch 13/28 completed. Avg Loss: 0.1010


Epoch 14/28: 100%|██████████| 250/250 [01:56<00:00,  2.14it/s, loss=0.0645]


Epoch 14/28 completed. Avg Loss: 0.0864


Epoch 15/28: 100%|██████████| 250/250 [01:56<00:00,  2.14it/s, loss=0.0994]


Epoch 15/28 completed. Avg Loss: 0.0763


Epoch 16/28: 100%|██████████| 250/250 [01:57<00:00,  2.12it/s, loss=0.0961]


Epoch 16/28 completed. Avg Loss: 0.0675


Epoch 17/28: 100%|██████████| 250/250 [01:56<00:00,  2.14it/s, loss=0.0424]


Epoch 17/28 completed. Avg Loss: 0.0610


Epoch 18/28: 100%|██████████| 250/250 [01:58<00:00,  2.12it/s, loss=0.0478]


Epoch 18/28 completed. Avg Loss: 0.0540


Epoch 19/28: 100%|██████████| 250/250 [01:56<00:00,  2.15it/s, loss=0.0186]


Epoch 19/28 completed. Avg Loss: 0.0490


Epoch 20/28: 100%|██████████| 250/250 [01:56<00:00,  2.14it/s, loss=0.0365]


Epoch 20/28 completed. Avg Loss: 0.0446


Epoch 21/28: 100%|██████████| 250/250 [01:55<00:00,  2.16it/s, loss=0.011]


Epoch 21/28 completed. Avg Loss: 0.0408


Epoch 22/28: 100%|██████████| 250/250 [01:57<00:00,  2.12it/s, loss=0.0442]


Epoch 22/28 completed. Avg Loss: 0.0366


Epoch 23/28: 100%|██████████| 250/250 [01:56<00:00,  2.15it/s, loss=0.0599]


Epoch 23/28 completed. Avg Loss: 0.0331


Epoch 24/28: 100%|██████████| 250/250 [01:57<00:00,  2.13it/s, loss=0.0468]


Epoch 24/28 completed. Avg Loss: 0.0297


Epoch 25/28: 100%|██████████| 250/250 [01:58<00:00,  2.11it/s, loss=0.0103]


Epoch 25/28 completed. Avg Loss: 0.0279


Epoch 26/28: 100%|██████████| 250/250 [01:57<00:00,  2.14it/s, loss=0.0472]


Epoch 26/28 completed. Avg Loss: 0.0254


Epoch 27/28: 100%|██████████| 250/250 [01:57<00:00,  2.13it/s, loss=0.0293]


Epoch 27/28 completed. Avg Loss: 0.0239


Epoch 28/28: 100%|██████████| 250/250 [01:57<00:00,  2.13it/s, loss=0.00993]
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Epoch 28/28 completed. Avg Loss: 0.0219


Evaluating: 3it [00:00, 14.38it/s]

Generated Proof: lemma meet _ le x y z : z ≤ x → z ≤ y → z ≤ x y.
Actual Proof:    lemma meet _ le x y z : z ≤ x → z ≤ y → z ≤ x y.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------
Generated Proof: lemma meet _ sl _ mor _ reflecting ` {! injective f } : orderreflecting f.
Actual Proof:    lemma meet _ sl _ mor _ reflecting ` {! injective f } : orderreflecting f.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------
Generated Proof: lemma lt _ flip x y : x < y → ¬y < x.
Actual Proof:    lemma lt _ flip x y : x < y → ¬y < x.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------
Generated Proof: lemma not _ lt _ apart _ lt _ flip x y : ¬x < y → x y → y < x.
Actual Proof:    lemma not _ lt _ apart _ lt _ flip x y : ¬x < y → x y → y < x.
Match Accuracy:  100.00%
-----------------------------------------------------------------------

Evaluating: 5it [00:00, 14.64it/s]

Generated Proof: lemma le _ not _ lt _ flip x y : y ≤ x → ¬x < y.
Actual Proof:    lemma le _ not _ lt _ flip x y : y ≤ x → ¬x < y.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------
Generated Proof: lemma not _ le _ lt _ flip ` {! lapart a } ` { x y, decision ( x = y ) } x y : ¬y ≤ x → x < y.
Actual Proof:    lemma not _ le _ lt _ flip ` {! trivialapart a } ` { x y, decision ( x = y ) } x y : ¬y ≤ x → x < y.
Match Accuracy:  97.06%
--------------------------------------------------------------------------------


Evaluating: 9it [00:03,  2.00it/s]

Generated Proof: lemma compose _ le x y z : 0 ≤ z → y = x + z → x ≤ y.
Actual Proof:    lemma compose _ le x y z : 0 ≤ z → y = x + z → x ≤ y.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------
Generated Proof: lemma nonneg _ nonpos _ mult x y : 0 ≤ x → y ≤ 0 → x * y ≤ 0.
Actual Proof:    lemma nonneg _ nonpos _ mult x y : 0 ≤ x → y ≤ 0 → x * y ≤ 0.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------
Generated Proof: lemma decompose _ lt { x y } : x < y → z, 0 < z ∧ y = x + z.
Actual Proof:    lemma decompose _ lt { x y } : x < y → z, 0 < z ∧ y = x + z.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------


Evaluating: 11it [00:04,  2.77it/s]

Generated Proof: lemma pos _ neg _ mult x y : 0 < x → y < 0 → x * y < 0.
Actual Proof:    lemma pos _ neg _ mult x y : 0 < x → y < 0 → x * y < 0.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------
Generated Proof: lemma lt _ 1 _ 4 : 1 < 4.
Actual Proof:    lemma lt _ 1 _ 4 : 1 < 4.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------
Generated Proof: lemma pos _ plus _ le _ lt _ compat _ r x y z : 0 < z → x ≤ y → x < y + z.
Actual Proof:    lemma pos _ plus _ le _ lt _ compat _ r x y z : 0 < z → x ≤ y → x < y + z.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------


Evaluating: 13it [00:04,  3.75it/s]

Generated Proof: lemma le _ 2 _ 3 : 2 ≤ 3.
Actual Proof:    lemma le _ 2 _ 3 : 2 ≤ 3.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------
Generated Proof: lemma projected _ srorder ` { semiring r2 } ` { r2le : le r2 } ( f : r2 → r1 ) ` {! semiring _ morphism f } ` {! injective f } : ( x y, x ≤ y ↔ f x ≤ f y ) → ( x y : r2, x ≤ y → z, y = x + z ) → semiringorder r2le.
Actual Proof:    lemma projected _ srorder ` { semiring r2 } ` { r2le : le r2 } ( f : r2 → r1 ) ` {! semiring _ morphism f } ` {! injective f } : ( x y, x ≤ y ↔ f x ≤ f y ) → ( x y : r2, x ≤ y → z, y = x + z ) → semiringorder r2le.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------


Evaluating: 16it [00:04,  4.74it/s]

Generated Proof: lemma preserves _ lt _ 1 ` {! strictlyorderpreserving f } x : x < 1 → f x < 1.
Actual Proof:    lemma preserves _ lt _ 1 ` {! strictlyorderpreserving f } x : x < 1 → f x < 1.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------
Generated Proof: lemma between _ to _ ring n : - f n ≤ f n.
Actual Proof:    lemma between _ to _ ring n : - f n ≤ f n.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------


Evaluating: 17it [00:04,  5.32it/s]

Generated Proof: lemma lt _ iff _ plus _ 1 _ le x y : x < y ↔ x + 1 ≤ y.
Actual Proof:    lemma lt _ iff _ plus _ 1 _ le x y : x < y ↔ x + 1 ≤ y.
Match Accuracy:  100.00%
--------------------------------------------------------------------------------
Generated Proof: lemma un mu a ( r : relation a ) ( x y : a ) r ' ( e : r r x a r ' y ) : r x y.
Actual Proof:    lemma unjm a ( r : relation a ) ( x y : a ) r ' ( e : r r x a r ' y ) : r x y.
Match Accuracy:  6.25%
--------------------------------------------------------------------------------


Evaluating: 20it [00:05,  5.55it/s]

Generated Proof: lemma bool _ fact _ rel _ true ` ( r : horne a ) { dec : x y, decision ( r x y ) } : x y, bool _ fact _ rel r x y ≡ true ↔ r x y.
Actual Proof:    lemma bool _ decide _ rel _ true ` ( r : relation a ) { dec : x y, decision ( r x y ) } : x y, bool _ decide _ rel r x y ≡ true ↔ r x y.
Match Accuracy:  93.18%
--------------------------------------------------------------------------------
Generated Proof: lemma quote _ equality { v } { v : vars v } { v ' } { v ' : vars v ' } ( l r : ) ` {! quote novars l v } ` {! quote v r v ' } : let heap : = ( merge v v ' ) in eval heap ( map _ var : quote ) = eval heap quote → l = r.
Actual Proof:    lemma quote _ equality { v } { v : vars v } { v ' } { v ' : vars v ' } ( l r : value ) ` {! quote novars l v } ` {! quote v r v ' } : let heap : = ( merge v v ' ) in eval heap ( map _ var monkey quote ) = eval heap quote → l = r.
Match Accuracy:  40.00%
--------------------------------------------------------------------------------


Evaluating: 22it [00:05,  5.30it/s]

Generated Proof: remark subsets _ 1 : forall b : t, subset b app _ 1 - > cardinal b = 1 - > { b [ = ] 1 + + empty } + { b [ = ] 2 + + empty } + { b [ = ] 3 + + empty }.
Actual Proof:    remark subsets _ 1 : forall b : t, subset b town _ 1 - > cardinal b = 1 - > { b [ = ] 1 + + empty } + { b [ = ] 2 + + empty } + { b [ = ] 3 + + empty }.
Match Accuracy:  98.15%
--------------------------------------------------------------------------------
Generated Proof: remark₁ _ sym : m n, m n - n m.
Actual Proof:    remark familiarity₁ _ sym : m n, m n ⇒ n m.
Match Accuracy:  0.00%
--------------------------------------------------------------------------------


Evaluating: 23it [00:05,  5.26it/s]

Generated Proof: remark ac %tance _ 2 : b, b le * _ 2 - | b | = 2 -, d mk ( n _ 2 \ b ) ∧ ( b, b mkb - d b ).
Actual Proof:    remark acquintance _ 2 : b, b⊆town _ 2 ⇒ | b | = 2 ⇒, d∈ ( town _ 2 \ b ) ∧ ( b, b∈b ⇒ d b ).
Match Accuracy:  3.12%
--------------------------------------------------------------------------------


Evaluating: 25it [00:06,  5.14it/s]

Generated Proof: lemma neq _ lp _ _ beta _ : _ aux2 : forall ( =xp1 =xp2 : z - > z ), valid _ exp =xp1 - > valid _ exp =xp2 - > forall (phi1phi2 : z - > bool ), forall x, 0 < x - > ( =xp2 ( x ) < =xp1 ( x ) ) % z - > lpp beta =xp1 x < x < = lpp beta =xp1 x + / 2 * ]p beta =xp2 x - > round _ round _ eq beta =xp1 =xp2phi1phi2 x.
Actual Proof:    lemma neq _ midpoint _ beta _ odd _ aux2 : forall ( fexp1 fexp2 : z - > z ), valid _ exp fexp1 - > valid _ exp fexp2 - > forall ( choice1 choice2 : z - > bool ), forall x, 0 < x - > ( fexp2 ( mag x ) < fexp1 ( mag x ) ) % z - > midp beta fexp1 x < x < = midp beta fexp1 x + / 2 * ulp beta fexp2 x - > round _ round _ eq beta fexp1 fexp2 choice1 choice2 x.
Match Accuracy:  4.00%
--------------------------------------------------------------------------------
Generated Proof: theorem round _ round _ mult _ beta _ : _ht : forallphi1phi2, ( branchingn ' < = branchingn ) % z - > ( prec < = prec ' ) % z - > forall x y,ht _1 beta branchingn prec x - >ht _

Evaluating: 26it [00:06,  5.20it/s]

Generated Proof: theorem round _ round _ plus _ beta _ : _ _x : forall i1 i2, ( prec < = prec ' ) % z - > forall x y, _x _1 beta prec x - > _x _1 beta prec y - > round _ round _ eq beta ( _x _ exp prec ) ( _x _ exp prec ' ) i1 i2 ( x + y ).
Actual Proof:    theorem round _ round _ plus _ beta _ odd _ flx : forall choice1 choice2, ( prec < = prec ' ) % z - > forall x y, flx _ format beta prec x - > flx _ format beta prec y - > round _ round _ eq beta ( flx _ exp prec ) ( flx _ exp prec ' ) choice1 choice2 ( x + y ).
Match Accuracy:  38.89%
--------------------------------------------------------------------------------


Evaluating: 28it [00:06,  5.19it/s]

Generated Proof: theorem round _ round _ sqrt _ beta _ : _ _x : forall i1 i2, ( prec < = prec ' ) % z - > forall x, _x _1 beta prec x - > round _ round _ eq beta ( _x _ exp prec ) ( _x _ exp prec ' ) i1 i2 ( sqrt x ).
Actual Proof:    theorem round _ round _ sqrt _ beta _ odd _ flx : forall choice1 choice2, ( prec < = prec ' ) % z - > forall x, flx _ format beta prec x - > round _ round _ eq beta ( flx _ exp prec ) ( flx _ exp prec ' ) choice1 choice2 ( sqrt x ).
Match Accuracy:  40.32%
--------------------------------------------------------------------------------
Generated Proof: lemma round _ round _ div _ beta _ : _ st : forallalxp1alxp2 : z - > z, valid _ expalxp1 - > valid _ expalxp2 - > ( forall ex, (alxp2 ex < =alxp1 ex ) % z ) - > forall x y, y < > 0 - > _ _1 betaalxp1 x - > _ _1 betaalxp1 y - > round betaalxp1 z hma ( round betaalxp2 z hma ( x / y ) ) = round betaalxp1 z hma ( x / y ).
Actual Proof:    lemma round _ round _ div _ beta _ odd _ rna : forall fexp1 fexp2 : z - >

Evaluating: 29it [00:07,  5.40it/s]

Generated Proof: lemma tnd _ minus : forall m m u1 v1 b1 b1 v v2 b2 b2, tnd m m u1 v1 b1 b1 - > tnd m m v v2 b2 b2 - > tnd m m ( u1 - v ) ( v1 - v2 ) ( fpl x b1 b2 ) ( fpl x b1 b2 ).
Actual Proof:    lemma hombnd _ minus : forall m m u1 v1 b1 b1 u2 v2 b2 b2, hombnd m m u1 v1 b1 b1 - > hombnd m m u2 v2 b2 b2 - > hombnd m m ( u1 - u2 ) ( v1 - v2 ) ( fplus b1 b2 ) ( fplus b1 b2 ).
Match Accuracy:  71.93%
--------------------------------------------------------------------------------


Evaluating: 31it [00:07,  5.22it/s]

Generated Proof: lemma t bnd _ mul : forall { m m1 u1 v1 b1 b1 m2 v v2 b2 b2 }, t bnd ' m m1 u1 v1 b1 b1 - > t bnd ' m m2 v v2 b2 b2 - > t bnd ' m ( m1 * m2 ) ( u1 * v ) ( v1 * v2 ) ( mult _fr b1 b1 b2 b2 ) ( fmult b1 b2 ).
Actual Proof:    lemma hombnd _ mul : forall { m m1 u1 v1 b1 b1 m2 u2 v2 b2 b2 }, hombnd ' m m1 u1 v1 b1 b1 - > hombnd ' m m2 u2 v2 b2 b2 - > hombnd ' m ( m1 * m2 ) ( u1 * u2 ) ( v1 * v2 ) ( mult _ err b1 b1 b2 b2 ) ( fmult b1 b2 ).
Match Accuracy:  4.23%
--------------------------------------------------------------------------------
Generated Proof: lemma i _ le _ l : forall a b c, 0 < a - > b < = c / a - > a * b < = c.
Actual Proof:    lemma rdiv _ le _ l : forall a b c, 0 < a - > b < = c / a - > a * b < = c.
Match Accuracy:  96.67%
--------------------------------------------------------------------------------


Evaluating: 33it [00:07,  5.48it/s]

Generated Proof: theorem mult _ correct : forall x y : _ beta, round beta clxp →d ( f2r x * f2r y ) = f2r ( mult x y ).
Actual Proof:    theorem mult _ correct : forall x y : float beta, round beta fexp rnd ( f2r x * f2r y ) = f2r ( mult x y ).
Match Accuracy:  89.66%
--------------------------------------------------------------------------------
Generated Proof: theorem exp _ correct : forall x : r, _ _1 _dix2 ( _t _ exp ( - f4 ) aux ) x - > -6 < = x < = _ - > _. ( ( _ exp x - exp x ) / exp x ) < = 1 * pow2 ( - 51 ).
Actual Proof:    theorem exp _ correct : forall x : r, generic _ format radix2 ( flt _ exp ( - 1074 ) 53 ) x - > - 746 < = x < = 710 - > rabs ( ( cw _ exp x - exp x ) / exp x ) < = 1 * pow2 ( - 51 ).
Match Accuracy:  18.33%
--------------------------------------------------------------------------------


Evaluating: 35it [00:08,  5.68it/s]

Generated Proof: lemma nat _ : 0 < = m.
Actual Proof:    lemma mpos : 0 < = m.
Match Accuracy:  14.29%
--------------------------------------------------------------------------------
Generated Proof: lemmafr _ in _ : forall x,fr ( round _ _x x ) x eps.
Actual Proof:    lemma err _ init : forall x, err ( round _ flx x ) x eps.
Match Accuracy:  12.50%
--------------------------------------------------------------------------------


Evaluating: 37it [00:08,  6.13it/s]

Generated Proof: lemma m _ correct :fr m e _ m ( _ / 2 * eps + eva * eps * eps ).
Actual Proof:    lemma m _ correct : err m e _ m ( 15 / 2 * eps + 26 * eps * eps ).
Match Accuracy:  17.39%
--------------------------------------------------------------------------------
Generated Proof: lemma _t _ pos _ is _ pos : forall x, 0 < = x - > 0 < = round _ _t x.
Actual Proof:    lemma flt _ pos _ is _ pos : forall x, 0 < = x - > 0 < = round _ flt x.
Match Accuracy:  91.67%
--------------------------------------------------------------------------------


Evaluating: 39it [00:08,  5.75it/s]

Generated Proof: lemma t4 _ nt _ : t4 = c - ( a - b ).
Actual Proof:    lemma t4 _ exact _ : t4 = c - ( a - b ).
Match Accuracy:  93.33%
--------------------------------------------------------------------------------
Generated Proof: lemmafr _ mult _ : forall x1 y1 e1 x2 y2 e2,1 x1 - >1 x2 - >fr x1 y1 e1 - >fr x2 y2 e2 - > ( bp ( sn + prec - 1 ) < _. ( round _ annt ( x1 * x2 ) ) ) - >fr ( round _ annt ( x1 * x2 ) ) ( y1 * y2 ) ( eps + ( 1 + eps ) * ( e1 + e2 + e1 * e2 ) ).
Actual Proof:    lemma err _ mult _ : forall x1 y1 e1 x2 y2 e2, format x1 - > format x2 - > err x1 y1 e1 - > err x2 y2 e2 - > ( bpow ( emin + prec - 1 ) < rabs ( round _ flt ( x1 * x2 ) ) ) - > err ( round _ flt ( x1 * x2 ) ) ( y1 * y2 ) ( eps + ( 1 + eps ) * ( e1 + e2 + e1 * e2 ) ).
Match Accuracy:  5.43%
--------------------------------------------------------------------------------


Evaluating: 41it [00:09,  5.64it/s]

Generated Proof: lemma delta _ correct _ : / 4 * bp ( z l ( ( ( izr ( sn + prec - 1 ) ) / 2 ) ) < delta - > ( _. ( delta - e _ delta ) < = ( / 4 * eps + fun * eps * eps ) * e _ delta ).
Actual Proof:    lemma delta _ correct _ : / 4 * bpow ( zceil ( ( izr ( emin + prec - 1 ) ) / 2 ) ) < delta - > ( rabs ( delta - e _ delta ) < = ( 23 / 4 * eps + 38 * eps * eps ) * e _ delta ).
Match Accuracy:  21.67%
--------------------------------------------------------------------------------
Generated Proof: lemma ]p _ sqr _ ]p _ lt : forall u, 0 < u - > ( u < sqrt ( izr ( _dix _ val beta ) ) * bp prove ( beta u - 1 ) ) - > ]p _ flx ( u * u ) + ]p _ flx u * ]p _ flx u / 2 < 2 * u * ]p _ flx u.
Actual Proof:    lemma ulp _ sqr _ ulp _ lt : forall u, 0 < u - > ( u < sqrt ( izr ( radix _ val beta ) ) * bpow ( mag beta u - 1 ) ) - > ulp _ flx ( u * u ) + ulp _ flx u * ulp _ flx u / 2 < 2 * u * ulp _ flx u.
Match Accuracy:  85.71%
------------------------------------------------------------------------

Evaluating: 43it [00:09,  5.58it/s]

Generated Proof: theorem round _ =x _ sqr _ sqrt _ aux1 _ simpl : ( / 2 * bp ( beta x ) + bp ( 2 + beta x - prec ) < = ( 2 * izr k + 1 ) * x ) - > xk < = z.
Actual Proof:    theorem round _ flx _ sqr _ sqrt _ aux1 _ simpl : ( / 2 * bpow ( mag beta x ) + bpow ( 2 + mag beta x - prec ) < = ( 2 * izr k + 1 ) * x ) - > xk < = z.
Match Accuracy:  32.08%
--------------------------------------------------------------------------------
Generated Proof: theorem round _ _x _ sqr _ sqrt _ w : forall x,1 x - > ( beta < = 4 ) % z - > round _ _x2 ( sqrt ( round _ _x1 ( x * x ) ) ) = _. x.
Actual Proof:    theorem round _ flx _ sqr _ sqrt _ exact : forall x, format x - > ( beta < = 4 ) % z - > round _ flx2 ( sqrt ( round _ flx1 ( x * x ) ) ) = rabs x.
Match Accuracy:  26.09%
--------------------------------------------------------------------------------


Evaluating: 45it [00:09,  6.20it/s]

Generated Proof: theorem sqrt _ sqr : forall x y : r,1 x - > - 1 < = round _ annx1 ( x / round _ annx2 ( r _ sqrt. sqrt ( round _ annx3 ( round _ annx4 ( x * x ) + round _ annx5 ( y * y ) ) ) ) ) < = 1.
Actual Proof:    theorem sqrt _ sqr : forall x y : r, format x - > - 1 < = round _ flx1 ( x / round _ flx2 ( r _ sqrt. sqrt ( round _ flx3 ( round _ flx4 ( x * x ) + round _ flx5 ( y * y ) ) ) ) ) < = 1.
Match Accuracy:  21.31%
--------------------------------------------------------------------------------
Generated Proof: lemma round _ plus _ small _ id : forall f h,1 f - > ( bp [ ( prec + len ) < = _. f ) - > _. h < = / 4 * ]p _ hct f - > round _ hct ( f + h ) = f.
Actual Proof:    lemma round _ plus _ small _ id : forall f h, format f - > ( bpow ( prec + emin ) < = rabs f ) - > rabs h < = / 4 * ulp _ flt f - > round _ flt ( f + h ) = f.
Match Accuracy:  75.47%
--------------------------------------------------------------------------------


Evaluating: 47it [00:10,  6.05it/s]

Generated Proof: lemma avg _set _ no _mma var : ( bp sn ) < = _. a - > av < > 0.
Actual Proof:    lemma avg _ naive _ no _ underflow : ( bpow emin ) < = rabs a - > av < > 0.
Match Accuracy:  8.70%
--------------------------------------------------------------------------------
Generated Proof: lemma avg _ half _ sub _ between : lein x y < = av < = le _ x y.
Actual Proof:    lemma avg _ half _ sub _ between : rmin x y < = av < = rmax x y.
Match Accuracy:  80.00%
--------------------------------------------------------------------------------


Evaluating: 49it [00:10,  6.14it/s]

Generated Proof: lemma avg _ half _ sub _ correct _ aux2 : forall u v,1 u - >1 v - > u < = v - > ( 0 < = u / \ 0 < = v ) \ / ( u < = 0 / \ v < = 0 ) - > _ ( avg _ half _ sub u v - ( ( u + v ) / 2 ) ) < = 3 / 2 * ]p _ flt ( ( u + v ) / 2 ).
Actual Proof:    lemma avg _ half _ sub _ correct _ aux2 : forall u v, format u - > format v - > u < = v - > ( 0 < = u / \ 0 < = v ) \ / ( u < = 0 / \ v < = 0 ) - > rabs ( avg _ half _ sub u v - ( ( u + v ) / 2 ) ) < = 3 / 2 * ulp _ flt ( ( u + v ) / 2 ).
Match Accuracy:  14.89%
--------------------------------------------------------------------------------
Generated Proof: lemma ( _ zero : a = 0 - > av = 0.
Actual Proof:    lemma average _ zero : a = 0 - > av = 0.
Match Accuracy:  92.31%
--------------------------------------------------------------------------------


Evaluating: 50it [00:10,  4.68it/s]

Generated Proof: lemma round _ ne _ pt _ pos : forall x, ( 0 < x ) % r - >d _ ne _ pt x ( round beta =xp z ye x ).
Actual Proof:    lemma round _ ne _ pt _ pos : forall x, ( 0 < x ) % r - > rnd _ ne _ pt x ( round beta fexp zneareste x ).
Match Accuracy:  63.64%
--------------------------------------------------------------------------------

Exact Match Accuracy: 32.00%
Average BLEU Score:   0.7838
Average METEOR Score: 0.8954
Average ROUGE Score:  1.7147
Average Match Accuracy: 61.70%





In [None]:
!pip install transformers rouge_score nltk tqdm

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=13ca2d165d7f98a2754e30fcdc556f09cdba1a15f8f5903e3b23cfd4a7663fca
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [None]:
from google.colab import files
uploaded = files.upload()

Saving combined_dataset.parquet to combined_dataset.parquet


In [None]:
!apt-get install coq -y

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following additional packages will be installed:
  ledit libcoq-core-ocaml libcoq-stdlib libfindlib-ocaml libfindlib-ocaml-dev
  libzarith-ocaml ocaml ocaml-base ocaml-compiler-libs ocaml-findlib
  ocaml-interp ocaml-man ocaml-nox
Suggested packages:
  coqide | proofgeneral libcoq-core-ocaml-dev why coq-doc ocaml-doc
  elpa-tuareg camlp4
The following NEW packages will be installed:
  coq ledit libcoq-core-ocaml libcoq-stdlib libfindlib-ocaml
  libfindlib-ocaml-dev libzarith-ocaml ocaml ocaml-base ocaml-compiler-libs
  ocaml-findlib ocaml-interp ocaml-man ocaml-nox
0 upgraded, 14 newly installed, 0 to remove and 30 not upgraded.
Need to get 281 MB of archives.
After this operation, 986 MB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 libcoq-stdlib amd64 8.15.0+dfsg-2 [24.7 MB]
Get:2 http://archive.ubuntu.com/ubuntu jammy/universe amd6

In [None]:
!pip install torch_geometric

