In [2]:
import argparse
import logging
import os
import pickle
import random
import torch
import json
import numpy as np
from tqdm import tqdm, trange   
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler,TensorDataset  
from gensim.models.doc2vec import Doc2Vec
from tokenize_code import tokenize_code

def to_embedding(code_list, model):
    """Turn each code snippet in ``code_list`` into a doc2vec embedding.

    Every snippet is tokenized with ``tokenize_code(snippet, 'code')`` and the
    resulting tokens are passed to ``model.infer_vector``.

    Args:
        code_list: iterable of code snippets (one entry per notebook cell).
        model: object exposing ``infer_vector(tokens)``; the notebook passes a
            loaded gensim ``Doc2Vec`` model.

    Returns:
        A list with one inferred vector per snippet, in input order.
    """
    vectors = []
    for snippet in code_list:
        tokens = tokenize_code(snippet, 'code')
        vectors.append(model.infer_vector(tokens))
    return vectors

class TextDataset(Dataset):
    """Dataset over a pickled cache of notebooks, embedding cells on access.

    Each record of the cache is unpacked as ``(cell_list, kernel_id)``, where
    ``cell_list`` is a sequence of ``(code, lib)`` pairs.
    """

    def __init__(self, cache_file, model):
        """Load the pickled cache and keep the embedding model.

        Args:
            cache_file: path to the pickle file holding the records.
            model: embedding model forwarded to ``to_embedding``.
        """
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserializing; only point this at cache files you produced yourself.
        # The context manager closes the handle even if unpickling fails.
        with open(cache_file, 'rb') as cache_handle:
            self.raw = pickle.load(cache_handle)
        self.model = model
        # Not read anywhere in this class; kept so external code that checks
        # the attribute keeps working.
        self.get_embedding = True

    def __len__(self):
        """Return the number of records in the cache."""
        return len(self.raw)

    def __getitem__(self, idx):
        """Return ``(embeddings, lib_list, kernel_id)`` for record ``idx``.

        ``embeddings`` holds one vector per cell (computed by
        ``to_embedding``) and ``lib_list`` the matching ``lib`` entry per
        cell, both in cell order.
        """
        cell_list, kernel_id = self.raw[idx]
        code_list = [code for code, _ in cell_list]
        lib_list = [lib for _, lib in cell_list]

        return to_embedding(code_list, self.model), lib_list, kernel_id

if __name__ == "__main__":
    # Load the trained doc2vec model from its directory on disk.
    doc2vec_path = '../doc2vec/model'
    model_file = doc2vec_path + "/notebook-doc2vec-model-apr24.model"
    model = Doc2Vec.load(model_file)

    # Wrap the pickled cache in a dataset that embeds cells on access.
    dataset = TextDataset('./cache_doc2vec.pkl', model)

    # Spot-check one record: (embeddings, lib_list, kernel_id).
    sample = dataset[597]
    print(sample)
 

([array([-1.88238725e-01,  2.25036830e-01,  3.79246950e-01,  3.42675149e-02,
        3.60706411e-02, -5.04446849e-02,  1.39585361e-01, -3.48320127e-01,
       -2.16195449e-01,  7.39025176e-02,  2.73990035e-01, -2.95626279e-02,
        2.33643949e-01, -4.39263552e-01, -1.77739590e-01,  2.56049216e-01,
       -1.19739801e-01, -3.06994338e-02,  9.41356793e-02,  2.27944225e-01,
       -7.48425946e-02,  1.54281676e-01, -3.80889118e-01, -2.69809127e-01,
       -1.26202703e-01,  1.67736977e-01,  2.99075276e-01,  9.41609312e-03,
       -1.62066981e-01,  2.36779034e-01, -1.68161198e-01,  1.55236095e-01,
        1.23640008e-01,  2.48196006e-01,  1.97682306e-01, -2.86018461e-01,
        1.33436741e-02, -1.22170551e-02, -2.53526382e-02,  9.63518471e-02,
       -2.22086728e-01,  9.66295972e-02,  8.46449658e-02, -1.57991141e-01,
        7.82672614e-02,  5.02413213e-02,  6.12152405e-02,  6.22246601e-02,
       -4.23125923e-01,  2.08798498e-01, -4.17008325e-02, -1.99491739e-01,
        5.38521148e-02,

In [3]:
def _to_object_array(items):
    """Pack ``items`` into a 1-D object array holding one element per item."""
    packed = np.empty(len(items), dtype=object)
    # Element-wise assignment stores each item untouched. Calling
    # np.array(items, dtype=object) directly only yields a 1-D array when the
    # nested lists are ragged; if every item has the same length it builds a
    # multi-dimensional object array and the saved layout silently changes.
    for position, item in enumerate(items):
        packed[position] = item
    return packed


# Walk the whole dataset once, collecting per-notebook embeddings, kernel ids
# and library names in parallel lists (same index == same notebook).
all_tensors = []
all_kernel_ids = []
all_libnames = []
for idx in trange(len(dataset)):
    embeds, libs, kernel_ids = dataset[idx]
    all_tensors.append(embeds)
    all_kernel_ids.append(kernel_ids)
    all_libnames.append(libs)

# Object arrays are pickled by np.save (which appends ".npy" to these names);
# reading them back needs np.load(..., allow_pickle=True).
np.save("./embed_tensors_apr29", _to_object_array(all_tensors))
np.save("./kernel_ids_apr29", _to_object_array(all_kernel_ids))
np.save("./lib_names_apr29", _to_object_array(all_libnames))

100%|██████████| 249315/249315 [2:55:08<00:00, 23.73it/s]   
