In [1]:
"""
Implements the V3 of the data setup, using the same seed and splits as v1
(src/notebooks/230209_completion_measure/230209_data_setup.ipynb) except
start sequent is not included."""

from tqdm import tqdm

from coprover.lemmaret.vectorizers import *

In [2]:
sbv = SBertVectorizer()

In [3]:
sbv.vectorize("<ANT> <CONS> s-formula apply constant apply constant constant constant <HID>").shape

(768,)

In [4]:
tokenizer = sbv.model.tokenizer

In [5]:
from transformers import (
    T5ForConditionalGeneration,
    MT5ForConditionalGeneration,
    ByT5Tokenizer,
    PreTrainedTokenizer,
    T5TokenizerFast as T5Tokenizer,
    MT5TokenizerFast as MT5Tokenizer,
)

In [6]:
tokenizer.encode("<ANT> <CONS> s-formula apply constant apply constant constant constant <HID>")

[4,
 38,
 423,
 62,
 40,
 1363,
 45,
 57,
 56,
 61,
 40,
 290,
 23,
 857,
 339,
 1983,
 390,
 418,
 1983,
 390,
 418,
 390,
 418,
 390,
 418,
 1363,
 1385,
 46,
 40,
 5]

In [7]:
# Now msuter the data
from pathlib import Path
import pandas as pd

from coprover import DATA_ROOT, RSC_ROOT

cmdpred_df = pd.read_csv(Path(RSC_ROOT, "pvs_cmd_pred", "data","cmdpred_N3.pvslib.tsv.gz"), sep='\t', names=['sequent', 'command', 'cmd_history', 'uri'])
                         

In [24]:
cmdpred_df[0:22]

Unnamed: 0,sequent,command,cmd_history,uri,proofname
0,<ANT> <CONS> s-formula forall ['variable'] ['v...,skosimp*,"NOOP,NOOP,NOOP",vect2_cont_comp-proofs/comp_rr_vr_cont#0,vect2_cont_comp-proofs/comp_rr_vr_cont
1,<ANT> <CONS> s-formula apply constant apply co...,typepred,"NOOP,NOOP,skosimp*",vect2_cont_comp-proofs/comp_rr_vr_cont#1,vect2_cont_comp-proofs/comp_rr_vr_cont
2,<ANT> s-formula apply constant apply constant ...,typepred,"NOOP,skosimp*,typepred",vect2_cont_comp-proofs/comp_rr_vr_cont#2,vect2_cont_comp-proofs/comp_rr_vr_cont
3,<ANT> s-formula apply constant apply constant ...,expand,"skosimp*,typepred,typepred",vect2_cont_comp-proofs/comp_rr_vr_cont#3,vect2_cont_comp-proofs/comp_rr_vr_cont
4,<ANT> s-formula apply constant apply constant ...,expand,"typepred,typepred,expand",vect2_cont_comp-proofs/comp_rr_vr_cont#4,vect2_cont_comp-proofs/comp_rr_vr_cont
5,<ANT> s-formula apply constant forall ['variab...,skosimp*,"typepred,expand,expand",vect2_cont_comp-proofs/comp_rr_vr_cont#5,vect2_cont_comp-proofs/comp_rr_vr_cont
6,<ANT> s-formula apply constant forall ['variab...,expand,"expand,expand,skosimp*",vect2_cont_comp-proofs/comp_rr_vr_cont#6,vect2_cont_comp-proofs/comp_rr_vr_cont
7,<ANT> s-formula apply constant forall ['variab...,inst,"expand,skosimp*,expand",vect2_cont_comp-proofs/comp_rr_vr_cont#7,vect2_cont_comp-proofs/comp_rr_vr_cont
8,<ANT> s-formula apply constant apply constant ...,inst,"skosimp*,expand,inst",vect2_cont_comp-proofs/comp_rr_vr_cont#8,vect2_cont_comp-proofs/comp_rr_vr_cont
9,<ANT> s-formula apply constant apply constant ...,expand,"expand,inst,inst",vect2_cont_comp-proofs/comp_rr_vr_cont#9,vect2_cont_comp-proofs/comp_rr_vr_cont


In [9]:
start_state = None
last_state = None
last_proofname = None

def get_proofname(uri):
    return uri.split('#', 1)[0]


proofnames = [get_proofname(row.uri) for idx, row in cmdpred_df.iterrows()]
    

In [10]:
cmdpred_df['proofname'] = proofnames

In [11]:
grp_obj = cmdpred_df.groupby('proofname')

In [12]:
proofnames = list(grp_obj.groups.keys())
print(len(proofnames))

7656


In [13]:
PROOFNAME = "proofname"
STATE = "source_text"
LABEL = "target_text"
POS = "pos"
NEG = "neg"
CMD_HISTORY = "cmd_history"

class MTuple:
    def __init__(self, proofname, start_row, end_row, label):
        self.proofname = proofname
        self.label = label
        self.start_row, self.end_row = start_row, end_row
        self.cmd_history = self.end_row.cmd_history
        assert str(self.start_row) != str(self.end_row)

    def __str__(self, str):
        return self.proofname
    
    def _statestr(self):
        cmdhist_str = self.end_row.cmd_history.replace(",", " ")
        return "{} {}".format(cmdhist_str, self.end_row.sequent)  # First naive formulation
    
    def as_row(self):
        return {
            PROOFNAME: self.proofname,
            STATE: self._statestr(),
            CMD_HISTORY: self.end_row.cmd_history.replace(",", " "),
            LABEL: self.label
        }

In [14]:
from random import Random

rnd = Random()
rnd.seed(1337)

pos_mtuples = []
neg_mtuples = []
hard_neg_mtuples = []

for proofname in tqdm(proofnames):
    rows = grp_obj.get_group(proofname)
    if len(rows) >= 3:
        start_row = rows.iloc[0]
        end_row = rows.iloc[len(rows)-1]
        hard_neg_row = rows.iloc[len(rows)-2]
        random_neg_row = rows.iloc[rnd.randint(1, len(rows) - 1)]
        pos_mtuples.append(MTuple(proofname, start_row, end_row, POS))
        neg_mtuples.append(MTuple(proofname, start_row, random_neg_row, NEG))
        hard_neg_mtuples.append(MTuple(proofname, start_row, hard_neg_row, NEG))

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7656/7656 [00:15<00:00, 483.84it/s]


In [15]:
idx=0
print("Start")
print(pos_mtuples[idx].start_row.sequent, pos_mtuples[idx].start_row.cmd_history)
print("\n\nPositive")
print(pos_mtuples[idx].end_row.sequent, pos_mtuples[idx].end_row.cmd_history)

print("\n\nNegative")
print(neg_mtuples[idx].end_row.sequent, neg_mtuples[idx].end_row.cmd_history)

print("\n\nHard Negative")
print(hard_neg_mtuples[idx].end_row.sequent, hard_neg_mtuples[idx].end_row.cmd_history)


Start
<ANT> <CONS> s-formula forall ['variable'] ['variable'] ['variable'] apply constant apply constant apply constant ['variable'] apply constant ['variable'] apply constant ['variable'] ['variable'] apply constant apply constant ['variable'] ['variable'] apply constant ['variable'] ['variable'] <HID>  NOOP,NOOP,NOOP


Positive
<ANT> s-formula apply constant apply constant type-actual apply constant integer constant apply constant constant apply constant apply constant constant constant apply constant constant constant s-formula apply constant apply constant type-actual apply constant constant constant integer s-formula apply constant forall ['variable'] apply constant apply constant apply constant type-actual ['variable'] integer apply constant type-actual ['variable'] constant apply constant apply constant ['variable'] constant s-formula apply constant apply constant constant integer s-formula apply constant apply constant type-actual apply constant constant constant apply constant

In [16]:
# save as a dataframe and then feed into simple_t5
# setup so only tuples below max token lengths can be used

filtered_rows = []
total = 0
for mt in pos_mtuples + neg_mtuples:
    entry = mt.as_row()
    total += 1
    if len(entry[STATE].split()) <= 1000:
        filtered_rows.append(entry)

print("Filtered size={}/{}".format(len(filtered_rows), total))
inst_df = pd.DataFrame(filtered_rows)
inst_df.to_csv("laststep_pred.v3.csv", header=True)

Filtered size=12403/12434


In [17]:
inst_df

Unnamed: 0,proofname,source_text,cmd_history,target_text
0,.ipynb_checkpoints/Euclids_30-checkpoint,factor expand inst S_<ANT> S_<CONS> S_s-formul...,factor expand inst,pos
1,.ipynb_checkpoints/all_least_bounded-checkpoint,flatten-disjunct flatten instantiate S_<ANT> S...,flatten-disjunct flatten instantiate,pos
2,.ipynb_checkpoints/cartesian_product_n_add_is_...,expand hide typepred S_<ANT> S_<CONS> S_s-form...,expand hide typepred,pos
3,.ipynb_checkpoints/cartesian_product_one_empty...,skolem! skosimp grind S_<ANT> S_<CONS> S_s-for...,skolem! skosimp grind,pos
4,.ipynb_checkpoints/chain_rule_TCC1-checkpoint,lemma instantiate inst? S_<ANT> S_<CONS> S_s-f...,lemma instantiate inst?,pos
...,...,...,...,...
12398,zorn-proofs/tower_intersection,expand* skolem skolem! S_<ANT> S_<CONS> S_s-fo...,expand* skolem skolem!,neg
12399,zorn-proofs/zorn,flatten lemma flatten-disjunct S_<ANT> S_<CONS...,flatten lemma flatten-disjunct,neg
12400,zp_group-proofs/Zn_card,split prop expand S_<ANT> S_<CONS> S_s-formula...,split prop expand,neg
12401,zp_group-proofs/Zn_finite,expand inst expand S_<ANT> S_<CONS> S_s-formul...,expand inst expand,neg


In [18]:
toklens = []
for _, x in inst_df.iterrows(): 
    toklens.append(len(x.source_text.split()))

In [19]:
max(toklens)


999

In [20]:
# See percentages of rows below different values of max token lengths
len([x for x in toklens if x <= 5000])

12403