In [9]:
from typing import List, Dict

from glob import glob
import pickle

from tqdm import tqdm

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

from griffon.coq_dataclasses import Stage1Token, Stage1Statement, Stage1Sample
from griffon.dataset.ct_coq_dataset import CTCoqDataset

In [10]:
# Recursively collect the pickle files under the stage1 test and valid directories.
test_files = glob("../data/processed/stage1/test/**/*.pickle", recursive=True)
valid_files = glob("../data/processed/stage1/valid/**/*.pickle", recursive=True)

# Load the vocabulary inside a context manager so the file handle is closed
# as soon as unpickling finishes instead of being left for the garbage collector.
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open("../models/vocab.pickle", "rb") as vocab_file:
    vocab = pickle.load(vocab_file)

# Coverage is measured over both sets of files together.
files = test_files + valid_files

In [11]:
# Running totals filled in by the coverage loop in this cell: every subtoken
# seen, and the subset of those that were found in `vocab`.
total_words = covered_words = 0

def process_tokens(tokens:List[Stage1Token]):
    """Count the subtokens of `tokens` and how many of them are in `vocab`.

    Args:
        tokens: iterable of objects exposing a `subtokens` iterable.

    Returns:
        A `(words, known_words)` tuple: the number of subtokens visited and
        the number of those for which `subtoken in vocab` is true. `vocab` is
        the module-level object loaded from the vocabulary pickle.
    """
    # One membership flag per subtoken, in visiting order.
    in_vocab = [
        subtoken in vocab
        for token in tokens
        for subtoken in token.subtokens
    ]
    # Each True counts as 1, so the sum is the number of known subtokens.
    return len(in_vocab), sum(in_vocab)

# Accumulate subtoken counts over every collected file: each hypothesis,
# the goal, and the lemma used all contribute to the totals.
for file in tqdm(files):
    # Close each handle as soon as the sample is loaded; a bare open() here
    # would leave one unclosed descriptor per file for the garbage collector.
    # NOTE: unpickling can execute arbitrary code; only load files you trust.
    with open(file, "rb") as sample_file:
        sample:Stage1Sample = pickle.load(sample_file)

    # Same order as the original stanzas: hypotheses, then goal, then lemma.
    token_groups = [hypothesis.tokens for hypothesis in sample.hypotheses]
    token_groups.append(sample.goal.tokens)
    # `lemma_used` is handed to process_tokens directly (no `.tokens` access).
    token_groups.append(sample.lemma_used)

    for tokens in token_groups:
        tw, cw = process_tokens(tokens)
        total_words += tw
        covered_words += cw



100%|██████████| 42714/42714 [01:16<00:00, 554.95it/s]


In [12]:
# Share of all counted subtokens that were found in `vocab`.
covered_words / total_words

0.778208040416655

In [6]:
def total_size(t:Tensor):
    """Return the number of bytes occupied by the elements of tensor `t`.

    Computed as the element count multiplied by the per-element byte size.
    """
    element_count = t.numel()
    bytes_per_element = t.element_size()
    return bytes_per_element * element_count

# NOTE(review): this cell needs a `sample` that exposes the tensor attributes
# listed below. The coverage loop annotates its `sample` as Stage1Sample, so
# confirm which cell is meant to define this one before relying on Run All.
#
# Stored as `sample_bytes` rather than `bytes` so the builtin `bytes` type is
# not shadowed for the rest of the kernel session.
sample_bytes = sum(
    total_size(tensor)
    for tensor in (
        sample.sequences,
        sample.extended_vocabulary_ids,
        sample.pointer_pad_mask,
        sample.distances_index,
        sample.distances_bins,
        sample.padding_mask,
        sample.lemma,
    )
)

# Convert to kibibytes (1024 bytes each) and display the result.
kb = sample_bytes / 1024
kb

82.951171875

In [4]:
# Display the repr of `sample` to inspect its fields.
sample

CTCoqSample(hypotheses=[CTCoqStatement(tokens=tensor([[7, 1, 1, 1, 1],
        [5, 4, 3, 1, 1],
        [6, 1, 1, 1, 1],
        [2, 1, 1, 1, 1]]), extended_vocabulary_ids=[7, 5, 4, 3, 6, 2], pointer_pad_mask=tensor([[ True, False, False, False, False],
        [ True,  True,  True, False, False],
        [ True, False, False, False, False],
        [ True, False, False, False, False]]), distances=[(tensor([[ 5, 31, 29, 24],
        [24,  1,  9, 16],
        [29, 21,  7, 22],
        [24, 27, 22,  6]]), tensor([1.0000e+04, 8.7975e-01, 1.0330e+00, 1.1538e+00, 1.2563e+00, 1.3411e+00,
        1.4866e+00, 1.6302e+00, 1.7595e+00, 1.8783e+00, 1.9815e+00, 2.0893e+00,
        2.2147e+00, 2.3878e+00, 2.4491e+00, 2.5427e+00, 2.6295e+00, 2.7803e+00,
        2.9124e+00, 3.0180e+00, 3.1951e+00, 3.2933e+00, 3.3694e+00, 3.4717e+00,
        3.6025e+00, 3.7403e+00, 3.9105e+00, 4.0540e+00, 4.1859e+00, 4.3461e+00,
        4.5866e+00, 4.8835e+00]), 'ppr'), (tensor([[1, 6, 6, 5],
        [6, 1, 3, 4],
    