In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F

## Define Sample Data

In [2]:
corpus = ["apple banana fruit", "banana apple fruit", "banana fruit apple",
                 "dog cat animal", "cat animal dog", "cat dog animal"]

In [3]:
corpus = [sent.split(" ") for sent in corpus]
corpus

[['apple', 'banana', 'fruit'],
 ['banana', 'apple', 'fruit'],
 ['banana', 'fruit', 'apple'],
 ['dog', 'cat', 'animal'],
 ['cat', 'animal', 'dog'],
 ['cat', 'dog', 'animal']]

In [4]:
#extracting unique words
flatten = lambda l: [item for sublist in l for item in sublist]
vocab = list(set(flatten(corpus)))
vocab

['cat', 'animal', 'apple', 'banana', 'fruit', 'dog']

In [5]:
#appending number for each word in the set
word2index = {w: i for i, w in enumerate(vocab)}
print(word2index)

{'cat': 0, 'animal': 1, 'apple': 2, 'banana': 3, 'fruit': 4, 'dog': 5}


In [6]:
#vocab size
voc_size = len(vocab)
print(voc_size)

6


In [8]:
#append UNK
vocab.append('<UNK>')

In [9]:
vocab

['cat', 'animal', 'apple', 'banana', 'fruit', 'dog', '<UNK>', '<UNK>']

In [10]:
#defining a method to get the word through index
index2word = {v:k for k, v in word2index.items()} 

## Data Preparation

In [11]:
def random_batch(batch_size, word_sequence):
    
    # Make skip gram of one size window
    skip_grams = []
    for sent in corpus:
        for i in range(1, len(sent) - 1):
            target = word2index[sent[i]]
            context = [word2index[sent[i - 1]], word2index[sent[i + 1]]]
            for w in context:
                skip_grams.append([target, w])
    
    random_inputs = []
    random_labels = []
    random_index = np.random.choice(range(len(skip_grams)), batch_size, replace=False) #randomly pick without replacement
        
    for i in random_index:
        random_inputs.append([skip_grams[i][0]])  # target, e.g., 2
        random_labels.append([skip_grams[i][1]])  # context word, e.g., 3
            
    return np.array(random_inputs), np.array(random_labels)

In [12]:
#testing the method
batch_size = 2 # mini-batch size
input_batch, target_batch = random_batch(batch_size, corpus)

print("Input: ",  input_batch)
print("Target: ", target_batch)

Input:  [[3]
 [2]]
Target:  [[4]
 [3]]


In [13]:
input_batch.shape, target_batch.shape

((2, 1), (2, 1))