<a href="https://colab.research.google.com/github/RohanLone/word_embedding/blob/main/NLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import io
import datetime
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds

def get_data(batch_size=10, shuffle_buffer=1000):
  """Load the IMDB subword dataset and return padded train/test batches.

  Args:
    batch_size: Number of examples per batch, used for both splits.
    shuffle_buffer: Buffer size passed to ``shuffle`` on the training split.

  Returns:
    A ``(train_batches, test_batches, encoder)`` tuple, where ``encoder`` is
    the encoder attached to the dataset's ``'text'`` feature.
  """
  (train_data, test_data), info = tfds.load(
      'imdb_reviews/subwords8k',
      split=(tfds.Split.TRAIN, tfds.Split.TEST),
      with_info=True,
      as_supervised=True)
  encoder = info.features['text'].encoder

  # The text component is padded along its single (variable) dimension;
  # the label component is a scalar and is left unpadded.
  padded_shapes = ([None], ())
  train_batches = train_data.shuffle(shuffle_buffer).padded_batch(
      batch_size, padded_shapes=padded_shapes)
  # Only the training split is shuffled.
  test_batches = test_data.padded_batch(
      batch_size, padded_shapes=padded_shapes)

  return train_batches, test_batches, encoder

def get_model(encoder, embedding_dim=16):
  """Build and compile the binary sentiment classifier.

  Args:
    encoder: Object exposing ``vocab_size``; it sets the number of rows in
      the embedding table.
    embedding_dim: Width of each embedding vector.

  Returns:
    A compiled ``tf.keras.Sequential`` model whose first layer is the
    embedding layer named ``"embedding"``.
  """
  model = tf.keras.Sequential([
    layers.Embedding(encoder.vocab_size, embedding_dim, name="embedding"),
    # Collapse the sequence dimension so the dense layers see one
    # fixed-size vector per example.
    layers.GlobalAveragePooling1D(),
    layers.Dense(16, activation='relu'),
    layers.Dropout(0.5),
    # Single sigmoid unit to pair with the binary cross-entropy loss.
    layers.Dense(1, activation='sigmoid'),
  ])
  model.compile(optimizer='adam',
                loss='binary_crossentropy',
                metrics=['accuracy'])

  return model


def retrieve_embeddings(model, encoder,
                        vectors_path='vecs.csv', metadata_path='meta.csv'):
  """Write the learned embedding vectors and their subwords to disk.

  One line is written per entry of ``encoder.subwords``: the subword itself
  goes to ``metadata_path`` and its tab-separated vector components go to
  ``vectors_path``, so line ``i`` of both files describes the same subword.

  Args:
    model: Model whose first layer holds the embedding matrix as the first
      array returned by ``get_weights()``.
    encoder: Object exposing ``subwords``, an ordered sequence of strings.
    vectors_path: Destination of the tab-separated vectors file.
    metadata_path: Destination of the one-subword-per-line file.
  """
  # Fetch the weights before creating any file so a failure here does not
  # leave empty output files behind.
  weights = model.layers[0].get_weights()[0]

  # ``with`` guarantees both files are closed even if a write raises.
  with open(vectors_path, 'w', encoding='utf-8') as out_vectors, \
       open(metadata_path, 'w', encoding='utf-8') as out_metadata:
    # Subword i is stored in row i + 1; row 0 is never exported
    # (presumably id 0 is reserved for padding -- confirm with the encoder).
    for row, word in enumerate(encoder.subwords, start=1):
      vec = weights[row]
      out_metadata.write(word + '\n')
      out_vectors.write('\t'.join(str(x) for x in vec) + '\n')

# Load the batched dataset together with its subword encoder.
train_batches, test_batches, encoder = get_data()

# Build the classifier and train it; each epoch's validation pass is limited
# to 20 batches drawn from the test split.
model = get_model(encoder)
history = model.fit(
    train_batches,
    validation_data=test_batches,
    validation_steps=20,
    epochs=10,
)

# Dump the trained embedding table and its subwords to disk.
retrieve_embeddings(model, encoder)



Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [None]:
# Offer the exported embedding files for download when running inside Colab.
try:
  from google.colab import files
except ImportError:
  # Not running in Colab: the files simply stay on the local disk.
  pass
else:
  # These must match the paths that retrieve_embeddings() writes by default.
  for path in ('vecs.csv', 'meta.csv'):
    try:
      files.download(path)
    except Exception as err:
      # Stay best-effort, but report the failure instead of hiding it.
      print('Could not download {}: {}'.format(path, err))