In [None]:
#!/usr/bin/python
# -*- coding: utf-8 -*-
import argparse
import logging
import os
import pickle
import random
import torch
import json
import numpy as np
import pandas as pd
from model import BertModel, Generator
from torch.nn import CrossEntropyLoss, MSELoss
from torch.utils.data import DataLoader, Dataset, SequentialSampler, \
    RandomSampler, TensorDataset
import math
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange
from pathlib import Path
import codecs
import json

from utils import CellFeatures, InputFeatures, \
    convert_examples_to_features, parseNotebook, get_notebook_list, \
    get_embedding
from config import *

from tokenize_code import tokenize_code
from gensim.models.doc2vec import Doc2Vec



class RetrivalDB:
    """In-memory index over pre-computed per-document embedding arrays.

    Attributes:
        embed: sequence of per-document arrays loaded from disk; each element
            contributes ``shape[0]`` rows to ``raw``.
        kernel_ids: identifiers loaded from disk, indexed by document position
            (presumably parallel to ``embed`` -- confirm against the files).
        idx_list: for every document, the row offset in ``raw`` where that
            document's rows start (non-decreasing, first entry is 0).
        raw: all document rows concatenated along axis 0.
    """

    def __init__(self):
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
        # load these files when they come from a trusted source.
        self.embed = np.load('embed_tensors_clean_apr29.npy',
                             allow_pickle=True)
        self.kernel_ids = np.load('kernel_ids_apr29.npy',
                                  allow_pickle=True)
        # Record the starting row of each document inside the flat matrix.
        self.idx_list = []
        offset = 0
        for doc in self.embed:
            self.idx_list.append(offset)
            offset += doc.shape[0]
        self.raw = np.concatenate(list(self.embed))

    def getDoc(self, raw_idx):
        """Map a row index of ``raw`` back to its document.

        Args:
            raw_idx: row index into ``self.raw``.

        Returns:
            ``(kernel_id, offset)`` where ``offset`` is the row's position
            inside its document, or ``None`` (after printing an error) when
            ``raw_idx`` is outside ``[0, raw.shape[0])``.
        """
        import bisect

        if raw_idx < 0 or raw_idx >= self.raw.shape[0]:
            print('ERROR: out of index')
            return None
        # The owning document is the last one whose start offset is
        # <= raw_idx.  bisect_right also handles the final document, for
        # which there is no "next start offset" to compare against.
        doc_pos = bisect.bisect_right(self.idx_list, raw_idx) - 1
        kernel_id = self.kernel_ids[doc_pos]
        return (kernel_id, raw_idx - self.idx_list[doc_pos])

    def find_sim(self, embed, topn=10):
        """Return the ``topn`` rows of ``raw`` with the highest dot product.

        Args:
            embed: query operand for the row-wise dot product with
                ``self.raw`` (must be compatible with ``self.raw`` under the
                ``'ij,ij->i'`` einsum -- TODO confirm the shape callers pass).
            topn: number of best-scoring rows to return.

        Returns:
            List of ``getDoc`` results, best score first.
        """
        result = np.einsum('ij,ij->i', self.raw, embed)
        rank = np.argsort(-result)[:topn]
        doc_list = [self.getDoc(r) for r in rank]
        return doc_list


if __name__ == '__main__':
    # Interactive retrieval loop: each round re-reads ./sample.ipynb, embeds
    # every source line with doc2vec, runs the embeddings through the saved
    # generator and prints the 10 best-matching database entries.

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the RNN model
    gen = torch.load('./gen_saved/best_gen.pt').to(device)
    gen.eval()

    # Load the doc2vec model
    model = Doc2Vec.load("../doc2vec/model/notebook-doc2vec-model-apr24.model")
    db = RetrivalDB()

    while True:
        print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
        # NOTE(review): the prompt says "sample.py" but the file actually
        # read below is ./sample.ipynb.
        input('Update the sample.py and press Enter to continue...')

        input_file = './sample.ipynb'
        embed_list = []
        # Close the handle every round (the loop never exits on its own) and
        # read the notebook as UTF-8 instead of the platform default encoding.
        with codecs.open(input_file, 'r', encoding='utf-8') as f:
            notebook = json.load(f)

        for cell in notebook['cells']:
            for line in cell['source']:
                # endswith() is safe on an empty string, unlike line[-1].
                if not line.endswith('\n'):
                    line = line + '\n'
                vector = model.infer_vector(tokenize_code(line, 'code'))
                embed_list.append(torch.Tensor(vector).to(device))
        predict_embed = gen.generate_embedding(embed_list)
        predict_embed = [embed.detach().cpu().numpy() for embed in predict_embed]

        doc_list = db.find_sim(predict_embed, topn=10)
        print(doc_list)
        # Not used anywhere in the visible code; kept for later cells/steps.
        file_path = '../doc2vec/data/sliced-notebooks-full-new'

$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$


Update the sample.py and press Enter to continue... 


[('ashrae-energy-prediction\\25644298_5', 1), ('ashrae-energy-prediction\\25840661_5', 1), ('ashrae-energy-prediction\\25644298_6', 1), ('ashrae-energy-prediction\\25644298_3', 1), ('ashrae-energy-prediction\\25644298_2', 1), ('ashrae-energy-prediction\\25644298_0', 1), ('ashrae-energy-prediction\\25840661_6', 1), ('ashrae-energy-prediction\\25644298_1', 1), ('ashrae-energy-prediction\\25644298_4', 1), ('ashrae-energy-prediction\\25840885_4', 1)]
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$


Update the sample.py and press Enter to continue... 


[('quora-insincere-questions-classification\\8544113_3', 1), ('ashrae-energy-prediction\\25644298_5', 1), ('ashrae-energy-prediction\\25644298_6', 1), ('ashrae-energy-prediction\\25840661_5', 1), ('quora-insincere-questions-classification\\8544113_4', 2), ('quora-insincere-questions-classification\\8544113_5', 2), ('ashrae-energy-prediction\\25644298_0', 1), ('ashrae-energy-prediction\\25840661_6', 1), ('ashrae-energy-prediction\\25644298_2', 1), ('ashrae-energy-prediction\\25644298_3', 1)]
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$


Update the sample.py and press Enter to continue... 


[('siim-isic-melanoma-classification\\41020093_5', 0), ('siim-isic-melanoma-classification\\41020093_4', 0), ('siim-isic-melanoma-classification\\38159585_0', 0), ('siim-isic-melanoma-classification\\38159585_4', 0), ('siim-isic-melanoma-classification\\41020093_2', 0), ('siim-isic-melanoma-classification\\41020093_3', 0), ('siim-isic-melanoma-classification\\38159585_1', 0), ('siim-isic-melanoma-classification\\38159585_2', 0), ('siim-isic-melanoma-classification\\41020093_0', 0), ('siim-isic-melanoma-classification\\38159585_3', 0)]
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$


Update the sample.py and press Enter to continue... 


[('quora-insincere-questions-classification\\8544113_3', 1), ('quora-insincere-questions-classification\\8544113_4', 2), ('quora-insincere-questions-classification\\8544113_5', 2), ('quora-insincere-questions-classification\\8544113_10', 2), ('quora-insincere-questions-classification\\8544113_9', 2), ('quora-insincere-questions-classification\\8544113_11', 2), ('kobe-bryant-shot-selection\\1890199_49', 1), ('ashrae-energy-prediction\\25644298_5', 1), ('ashrae-energy-prediction\\25840661_5', 1), ('ashrae-energy-prediction\\25644298_6', 1)]
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
