In [2]:
import torch.nn as nn
import torch

In [3]:
class WordEmbeddingLayer(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(WordEmbeddingLayer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size, bias=False)

    def forward(self, x):
        return self.linear(self.embedding(x).mean(dim=1))
    
    def get_embedding(self,idx):
        return self.embedding(torch.tensor([idx]))

In [1]:
# read the data from the file and load each line into a list
def read_data(filename):
    with open(filename) as f:
        data = f.readlines()
        data = [line.replace('\n','<eos>') for line in data]
    return data

sentences = read_data("sentences.txt")
print(sentences)
print('no of sentences:',len(sentences))

['The oval is red.<eos>', 'The star is white.<eos>', 'This grape is black.<eos>', 'Children dance a lot.<eos>', 'I love banana juice.<eos>', 'mangos are healthy to eat.<eos>', 'Draw a green oval.<eos>', 'I like green fruits like cherry.<eos>', 'A red cat is rare.<eos>', 'A diamond has many sides.<eos>', 'I love banana juice.<eos>', 'The elephant likes to dance.<eos>', 'A blue lion is rare.<eos>', 'The lion likes to fly.<eos>', 'I love mango juice.<eos>', 'elephants are brown.<eos>', 'I like purple fruits like grape.<eos>', 'The fox is run.<eos>', 'The bear is sit.<eos>', 'Children climb a lot.<eos>', 'A heart has many sides.<eos>', 'This strawberry is black.<eos>', 'Draw a orange hexagon.<eos>', 'The fish is sleep.<eos>', 'I saw a red fish.<eos>', 'The red diamond looks beautiful.<eos>', 'A hexagon has many sides.<eos>', 'The bear likes to dance.<eos>', 'I love watermelon juice.<eos>', 'The triangle looks perfect.<eos>', 'The tiger is jump.<eos>', 'Draw a pink hexagon.<eos>', 'I saw a 

In [6]:
import re

def tokenize(text):
    return re.findall(r'\b\w+\b', text.lower())
tokens = []
for i in list(map(tokenize, sentences)):
    tokens.extend(i)
tokens = list(set(tokens))
print("no of tokens",len(tokens))

no of tokens 93


In [7]:
word2idx = {word: i for i, word in enumerate(tokens)}
idx2word = {i: word for i, word in enumerate(tokens)}

In [8]:
wordEmbeddings = WordEmbeddingLayer(len(tokens), 3)
def get_word_embedding(word):
    return wordEmbeddings.get_embedding(word2idx[word])

def get_embedding_to_idx(idx):
    return wordEmbeddings.get_embedding(idx)

In [9]:
context = []
target = []
for i in sentences:
    words = re.findall(r'\b\w+\b', i.lower())
    for j in range(1,len(words)):
        context.append(list(map(lambda x:word2idx[x],words[:j])))
        target.append(word2idx[words[j]])

from torch.nn.utils.rnn import pad_sequence

context = [torch.tensor(c) for c in context]
target = torch.tensor(target)

context = pad_sequence(context, batch_first=True, padding_value=0)

train_data = list(zip(context,target))

In [10]:
from torch.utils.data import Dataset, DataLoader

class WordEmbeddingDataset(Dataset):
    def __init__(self, data, target):
        self.data = data
        self.target = target
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx] , self.target[idx]

In [11]:
dataset = WordEmbeddingDataset(context, target)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


In [12]:
criteria = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(wordEmbeddings.parameters(), lr=0.01)

def train(epoch):
    for i in range(epoch):
        total_loss = 0
        for context,target in dataloader:
            output = wordEmbeddings(context)
            loss = criteria(output,target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {i}, Loss: {total_loss / len(dataloader)}")

In [23]:
train(3000)

Epoch 0, Loss: 0.8159064501523972
Epoch 1, Loss: 0.8172693252563477
Epoch 2, Loss: 0.8155029181923185
Epoch 3, Loss: 0.8138896503618785
Epoch 4, Loss: 0.8111737774951118
Epoch 5, Loss: 0.8157029322215489
Epoch 6, Loss: 0.8063665947743824
Epoch 7, Loss: 0.8151538840362004
Epoch 8, Loss: 0.8185347084488187
Epoch 9, Loss: 0.8191166392394474
Epoch 10, Loss: 0.818647529397692
Epoch 11, Loss: 0.820947425706046
Epoch 12, Loss: 0.8150975661618369
Epoch 13, Loss: 0.8156063386372158
Epoch 14, Loss: 0.8200003504753113
Epoch 15, Loss: 0.8206850226436343
Epoch 16, Loss: 0.8144237569400242
Epoch 17, Loss: 0.8180026518447059
Epoch 18, Loss: 0.8217268522296634
Epoch 19, Loss: 0.8147914558649063
Epoch 20, Loss: 0.8250282321657453
Epoch 21, Loss: 0.8144305646419525
Epoch 22, Loss: 0.8175288353647504
Epoch 23, Loss: 0.8121400943824223
Epoch 24, Loss: 0.8110355458089283
Epoch 25, Loss: 0.8118887799126762
Epoch 26, Loss: 0.8115276651723045
Epoch 27, Loss: 0.8131700669016156
Epoch 28, Loss: 0.81511097507817

In [25]:
torch.save(wordEmbeddings.state_dict(), "3d_word_embeddings.pth")

In [28]:
import plotly.graph_objects as go


fig = go.Figure()

x, y, z, texts = [], [], [], []

for i in range(len(tokens)):
    embedding = get_embedding_to_idx(i)
    num_emd = embedding.detach().numpy()
    x.append(num_emd[0][0])
    y.append(num_emd[0][1])
    z.append(num_emd[0][2])
    texts.append(idx2word[i]) 

fig.add_trace(go.Scatter3d(
    x=x, y=y, z=z,
    mode='markers+text',  
    text=texts,  # Annotations for each point
    textposition='top center', 
    marker=dict(size=5, color='red'),
))

fig.update_layout(
    scene=dict(
        xaxis_title='X-axis',
        yaxis_title='Y-axis',
        zaxis_title='Z-axis'
    ),
    title="Interactive 3D Point Plot",
)

fig.show()
fig.write_html("simple_3d_we_outputs.html")
