In [5]:
import griffon
from griffon.coq_dataclasses import *
import pickle

from griffon.preprocessing import Tokenizer, Vocab

import torch
from torch.nn.utils.rnn import pad_sequence


import os
from glob import glob

In [6]:
paths = glob("recommandations/test/*.pickle")

sentences = [pickle.load(open(filename, "rb"))[1] for filename in paths[:1000]]
with open("small_recommandations/train/train.goals", "w") as goal_file:
    with open("small_recommandations/train/train.used", "w") as used_file:
        for sentence in sentences:
            if sentence.used_item.type_text is not None and \
               sentence.goal.type_text:
                goal_file.write(sentence.goal.type_text + "\n")
                used_file.write(sentence.used_item.type_text + "\n")

In [None]:
!ls small_recommandations/train

In [7]:
import math
import io 

# from torchtext
def read_text_iterator(path):
    with io.open(path, encoding="utf8") as f:
        for row in f:
            yield row

class MyIterableDataset(torch.utils.data.IterableDataset):
     def __init__(self, data_root:str):
         super(MyIterableDataset).__init__()
         self.goal_path = os.path.join(data_root, "train.goals")
         self.used_path = os.path.join(data_root, "train.used")

     def __iter__(self):
        goal_iter = read_text_iterator(self.goal_path)
        used_iter = read_text_iterator(self.used_path)
        return zip(goal_iter, used_iter)

ds = MyIterableDataset("small_recommandations/train")



# Single-process loading
#print(list(torch.utils.data.DataLoader(ds, num_workers=0, batch_size=None)))



In [9]:
def list_transforms(*transforms):
    def func(txt_inputs):
        for transform in transforms:
            txt_inputs = [transform(txt_input) for txt_input in txt_inputs]
        return txt_inputs
    return func

vocab:Vocab = pickle.load(open("vocab.pickle", "rb"))
tokenize_transform = Tokenizer()
vocab_transform = vocab.sentence_to_tensor

pipeline = list_transforms(tokenize_transform,
                           vocab_transform)

test_input = ["forall (x y : Carrier (cart E F)) (_ : @Equal (cart E F) x y), @Equal E (proj1 x) (proj1 y) somerandomword", 
             "forall (x y : Carrier (cart E F))"]

pipeline(test_input)

def collate_fn(batch):
    goal_batch, used_batch = [], []
    for goal, used in batch:
        goal_batch.append(pipeline(goal.rstrip("\n")))
        used_batch.append(pipeline(used.rstrip("\n")))
    
    print(goal_batch)
    goal_batch = pad_sequence(goal_batch, vocab.PAD_IDX)
    used_batch = pad_sequence(used_batch, vocab.PAD_IDX)
    return goal_batch, used_batch

itera = torch.utils.data.DataLoader(ds, num_workers=0, batch_size=2, collate_fn=collate_fn)
for _, batch in zip(range(1), itera):
    print(batch)

[[array([5928, 5929]), array([5928,   68, 5929]), array([5928,   67, 5929]), array([5928,   33, 5929]), array([5928,   49, 5929]), array([5928,   11, 5929]), array([5928,   67, 5929]), array([5928,   74, 5929]), array([5928, 5929]), array([5928,   68, 5929]), array([5928,   23, 5929]), array([5928,   33, 5929]), array([5928,   77, 5929]), array([5928, 5929]), array([5928,   11, 5929]), array([5928,   67, 5929]), array([5928,   77, 5929]), array([5928, 5929]), array([5928,   49, 5929]), array([5928,   11, 5929]), array([5928,   82, 5929]), array([5928,   23, 5929]), array([5928,   68, 5929]), array([5928,   49, 5929]), array([5928,   23, 5929]), array([5928,   11, 5929]), array([5928,   77, 5929]), array([5928, 5929]), array([5928,   68, 5929]), array([5928,   67, 5929]), array([5928,   84, 5929]), array([5928,  138, 5929]), array([5928,   67, 5929]), array([5928,   53, 5929]), array([5928,   77, 5929]), array([5928, 5929]), array([5928,   82, 5929]), array([5928,  582, 5929]), array([5

TypeError: expected Tensor as element 0 in argument 0, but got list