<a href="https://colab.research.google.com/github/Santosh-Gupta/NaturalLanguageRecommendations/blob/srihari-dev/notebooks/model_debug.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%tensorflow_version 2.x

UsageError: Line magic function `%tensorflow_version` not found.


In [2]:
import tensorflow as tf
from tqdm.notebook import tqdm
import os

print('TensorFlow:', tf.__version__)

TensorFlow: 2.1.0-rc0


In [3]:
batch_size  = 8
embedding_dim = 512
autotune = tf.data.experimental.AUTOTUNE

In [4]:
def get_random_title():
    return tf.random.uniform(shape=[512], maxval=200, dtype=tf.int32)

def get_random_citation():
    vector = tf.random.uniform(shape=[embedding_dim], minval=-1, maxval=1, dtype=tf.float32)
    normed_vector = tf.math.l2_normalize(vector)
    return normed_vector

def generate_sample():
    title = get_random_title()
    posCitations = get_random_citation()
    return title, posCitations

In [5]:
class TFrecordWriter:
    def __init__(self,
                 n_samples,
                 n_shards,
                 output_dir='',
                 prefix=''):
        self.n_samples = n_samples
        self.n_shards = n_shards
        self.step_size = self.n_samples//self.n_shards + 1
        self.prefix = prefix
        self.output_dir = output_dir
        self.buffer = []
        self.file_count = 1
        
    def make_example(self, title, vector):
        feature = {
            'title': tf.train.Feature(int64_list=tf.train.Int64List(value=title)),
            'citation': tf.train.Feature(float_list=tf.train.FloatList(value=vector))
        }
        return tf.train.Example(features=tf.train.Features(feature=feature))
        
    def write_tfrecord(self, tfrecord_path):
        print('writing {} samples in {}'.format(len(self.buffer), tfrecord_path))
        with tf.io.TFRecordWriter(tfrecord_path) as writer:
            for (title, vector) in tqdm(self.buffer):
                example = self.make_example(title, vector)
                writer.write(example.SerializeToString())
    
    def push(self, title, vector):
        self.buffer.append([title, vector])
        if len(self.buffer) == self.step_size:
            fname = self.prefix + '_000' + str(self.file_count) + '.tfrecord'
            tfrecord_path = os.path.join(self.output_dir, fname)
            self.write_tfrecord(tfrecord_path)
            self.clear_buffer()
            self.file_count += 1
            
    def flush_last(self):
        fname = self.prefix + '_000' + str(self.file_count) + '.tfrecord'
        tfrecord_path = os.path.join(self.output_dir, fname)
        self.write_tfrecord(tfrecord_path)
            
    def clear_buffer(self):
        self.buffer = []

In [6]:
tfrecord_writer = TFrecordWriter(1000, 16, 'tfrecords', 'train')

In [7]:
for i in range(1000):
    title, vector = generate_sample()
    tfrecord_writer.push(title, vector)
tfrecord_writer.flush_last()

writing 63 samples in tfrecords/train_0001.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 63 samples in tfrecords/train_0002.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 63 samples in tfrecords/train_0003.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 63 samples in tfrecords/train_0004.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 63 samples in tfrecords/train_0005.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 63 samples in tfrecords/train_0006.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 63 samples in tfrecords/train_0007.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 63 samples in tfrecords/train_0008.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 63 samples in tfrecords/train_0009.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 63 samples in tfrecords/train_00010.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 63 samples in tfrecords/train_00011.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 63 samples in tfrecords/train_00012.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 63 samples in tfrecords/train_00013.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 63 samples in tfrecords/train_00014.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 63 samples in tfrecords/train_00015.tfrecord


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))


writing 55 samples in tfrecords/train_00016.tfrecord


HBox(children=(IntProgress(value=0, max=55), HTML(value='')))


