In [None]:
# Read the whole corpus and lower-case it. The encoding is stated explicitly so the
# decoded text does not depend on the platform's default locale encoding.
with open('news_corpus.txt', 'r', encoding='utf-8') as file:
  content = file.read().lower()
  words = content.split()

# Keep purely alphabetic tokens. No separate mojibake check for 'Ã¦' is needed:
# after .lower() the text contains 'ã', never 'Ã', and '¦' is not alphabetic, so
# str.isalpha() already rejects every token containing that sequence.
words = [word for word in words if word.isalpha()]

In [None]:
# The vocabulary is the set of distinct tokens that survived cleaning
vocab = {token for token in words}
vocab_size = len(vocab)

In [None]:
# Number of distinct words in the vocabulary (rendered as the cell's output)
vocab_size

33482

In [None]:
# Assign indices in sorted order. Iterating a set of strings follows the per-process
# hash seed, so enumerating `vocab` directly would give a different word<->index
# mapping after every kernel restart (breaking saved weights / reproducibility).
word_to_index = {word: index for index, word in enumerate(sorted(vocab))}
# Derive the inverse from the forward map so the two are consistent by construction
index_to_word = {index: word for word, index in word_to_index.items()}

In [None]:
# CBOW training pairs: the two words on each side of position i (kept in
# left-to-right order) form the context, and the word at i is the target.
data = [
  (words[i - 2:i] + words[i + 1:i + 3], words[i])
  for i in range(2, len(words) - 2)
]

In [None]:
import torch
import torch.nn as nn

In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

In [None]:
# Show which device the tensors and model will be placed on
device

device(type='cuda')

In [None]:
def make_context_vector(context, word_to_index):
  """Convert a list of context words into a 1-D LongTensor of vocabulary indices.

  The tensor is created on the notebook-level `device`. A word missing from
  `word_to_index` raises KeyError.
  """
  return torch.tensor(
      [word_to_index[word] for word in context],
      dtype=torch.long,
      device=device,
  )

In [None]:
EMBEDDING_DIM = 100

# Class to build CBOW model
class CBOW(torch.nn.Module):
  """Continuous bag-of-words model: predict a word from the sum of its context embeddings.

  Embedding(vocab_size, embedding_dim) -> sum over context words
  -> Linear(embedding_dim, 128) -> ReLU -> Linear(128, vocab_size) -> LogSoftmax.
  """

  def __init__(self, vocab_size, embedding_dim):
    super(CBOW, self).__init__()

    self.embeddings = nn.Embedding(vocab_size, embedding_dim)
    self.linear1 = nn.Linear(embedding_dim, 128)
    self.activation_function1 = nn.ReLU()

    self.linear2 = nn.Linear(128, vocab_size)
    self.activation_function2 = nn.LogSoftmax(dim=-1)

  def forward(self, inputs):
    """Return log-probabilities over the vocabulary.

    Args:
      inputs: LongTensor of word indices, either 1-D ``(context_len,)`` for one
        context or 2-D ``(batch, context_len)`` for a batch of contexts.

    Returns:
      ``(1, vocab_size)`` for 1-D input, ``(batch, vocab_size)`` for 2-D input.
    """
    if inputs.dim() == 1:
      # Treat a single context as a batch of one so both cases share one path
      inputs = inputs.unsqueeze(0)
    # (batch, context_len, embedding_dim) -> (batch, embedding_dim); a tensor
    # reduction instead of Python's builtin sum iterating over rows
    embeds = self.embeddings(inputs).sum(dim=1)
    out = self.linear1(embeds)
    out = self.activation_function1(out)
    out = self.linear2(out)
    out = self.activation_function2(out)
    return out

  def get_word_embedding(self, word, vocab_index=None):
    """Return the embedding vector of ``word`` as a ``(1, embedding_dim)`` tensor.

    Args:
      word: token to look up.
      vocab_index: optional word -> index mapping; defaults to the notebook's
        global ``word_to_index``.

    Raises:
      KeyError: if ``word`` is not in the mapping.
    """
    mapping = word_to_index if vocab_index is None else vocab_index
    # Create the index on whatever device the embedding weights live on
    word_index = torch.tensor(mapping[word], device=self.embeddings.weight.device)
    return self.embeddings(word_index).view(1, -1)

In [None]:
# Seed torch's RNG (all devices) so the randomly initialised weights are reproducible
SEED = 42
torch.manual_seed(SEED)
model = CBOW(vocab_size, EMBEDDING_DIM).to(device)

In [None]:
# NLLLoss pairs with the model's LogSoftmax output; plain SGD updates every parameter
LEARNING_RATE = 0.001
loss_function = nn.NLLLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# Training: one SGD step per (context, target) example
NUM_EPOCHS = 50

for epoch in range(NUM_EPOCHS):
  total_loss = 0.0

  for context_words, target_word in data:
    context_vector = make_context_vector(context_words, word_to_index)
    target_vector = torch.tensor([word_to_index[target_word]], device=device)

    model.zero_grad()
    log_probs = model(context_vector)
    loss = loss_function(log_probs, target_vector)
    loss.backward()
    optimizer.step()

    # .item() yields a plain float; summing the loss tensor itself would keep every
    # step's autograd graph alive until the epoch ends (steadily growing memory).
    total_loss += loss.item()

  print(f'epoch {epoch + 1}/{NUM_EPOCHS}: total loss {total_loss:.4f}')

In [None]:
# `data` (context/target pairs, two words on each side) was already built by the
# earlier dataset cell and is not modified since, so it is reused rather than rebuilt.

# Group the data into batches of 64 examples (the final batch may be smaller)
batch_size = 64
batches = [data[start:start + batch_size] for start in range(0, len(data), batch_size)]

class CBOW(torch.nn.Module):
  """CBOW model used for mini-batch training (replaces any earlier CBOW definition).

  Embedding(vocab_size, embedding_dim) -> sum over context words
  -> Linear(embedding_dim, 128) -> ReLU -> Linear(128, vocab_size) -> LogSoftmax.
  Accepts both single contexts and batches, and keeps `get_word_embedding` so
  later cells that inspect embeddings still work with this class.
  """

  def __init__(self, vocab_size, embedding_dim):
    super(CBOW, self).__init__()

    self.embeddings = nn.Embedding(vocab_size, embedding_dim)
    self.linear1 = nn.Linear(embedding_dim, 128)
    self.activation_function1 = nn.ReLU()

    self.linear2 = nn.Linear(128, vocab_size)
    self.activation_function2 = nn.LogSoftmax(dim=-1)

  def forward(self, inputs):
    """Return log-probabilities over the vocabulary.

    Args:
      inputs: LongTensor of word indices, 2-D ``(batch, context_len)`` or 1-D
        ``(context_len,)`` for a single context.

    Returns:
      ``(batch, vocab_size)`` for 2-D input, ``(1, vocab_size)`` for 1-D input.
    """
    if inputs.dim() == 1:
      # A lone context becomes a batch of one; otherwise dim=1 below would sum
      # over the embedding features instead of over the context words.
      inputs = inputs.unsqueeze(0)
    # Sum the embeddings of the context words: (batch, ctx, dim) -> (batch, dim)
    embeds = torch.sum(self.embeddings(inputs), dim=1)
    out = self.linear1(embeds)
    out = self.activation_function1(out)
    out = self.linear2(out)
    out = self.activation_function2(out)
    return out

  def get_word_embedding(self, word, vocab_index=None):
    """Return the embedding vector of ``word`` as a ``(1, embedding_dim)`` tensor.

    Args:
      word: token to look up.
      vocab_index: optional word -> index mapping; defaults to the notebook's
        global ``word_to_index``.

    Raises:
      KeyError: if ``word`` is not in the mapping.
    """
    mapping = word_to_index if vocab_index is None else vocab_index
    word_index = torch.tensor(mapping[word], device=self.embeddings.weight.device)
    return self.embeddings(word_index).view(1, -1)

# Fresh model for the mini-batch run; seeding makes the initial weights reproducible
SEED = 42
torch.manual_seed(SEED)
model = CBOW(vocab_size, EMBEDDING_DIM).to(device)
loss_function = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# Train the model on mini-batches
NUM_EPOCHS = 50

for epoch in range(NUM_EPOCHS):
  total_loss = 0.0

  for batch in batches:
    # Distinct names so the module-level `context`/`target` are not clobbered
    batch_contexts, batch_targets = zip(*batch)
    # Build the (batch, context_len) index matrix with one tensor creation instead
    # of one tensor per row followed by torch.stack
    context_vectors = torch.tensor(
        [[word_to_index[w] for w in ctx] for ctx in batch_contexts],
        dtype=torch.long,
        device=device,
    )
    target_tensors = torch.tensor([word_to_index[t] for t in batch_targets], device=device)

    model.zero_grad()
    log_probs = model(context_vectors)
    loss = loss_function(log_probs, target_tensors)
    loss.backward()
    optimizer.step()

    # .item() detaches the value; accumulating the tensor would retain every
    # batch's autograd graph for the whole epoch.
    total_loss += loss.item()

  print(f'epoch {epoch + 1}/{NUM_EPOCHS}: total loss {total_loss:.4f}')


In [None]:
# Testing
# Testing: predict the centre word for a hand-picked 4-word context
context = ['anne','green', 'and', 'out']
context_vector = make_context_vector(context, word_to_index)
# The CBOW from the mini-batch cell sums over dim=1, so feed a batch of one
# with shape (1, 4); no gradients are needed for inference.
with torch.no_grad():
  output = model(context_vector.unsqueeze(0))

In [None]:
# Print result
print(f'Context: {context}\n')
print(f'Prediction: {index_to_word[torch.argmax(output[0]).item()]}')

Context: ['anne', 'green', 'and', 'out']

Prediction: of


In [None]:
# Inspect the learned embedding vector for "good"
good_embedding = model.get_word_embedding("good")
print(f'embedding:{good_embedding}')

embedding:tensor([[ 0.0344,  0.3045, -0.4319, -1.7686,  0.1896, -0.1149,  1.3540, -2.0016,
          1.3115,  0.6336, -0.6434,  0.1296, -0.6239, -0.1096, -1.1448,  0.7652,
          0.7295,  0.3042, -0.6074,  0.3363,  1.3657,  0.5275, -0.2030,  1.4025,
          0.6864, -1.6390,  0.4457, -1.4422, -0.0918, -1.1462,  1.1658, -0.3711,
         -1.2876,  0.6065, -0.3170,  0.2223, -0.0087,  0.1301,  0.1048, -0.0708,
          1.6249,  0.2964, -0.9412,  0.8288, -1.3666, -0.4472,  0.3989,  0.4394,
          0.0505, -0.7859,  0.1017, -0.4291, -0.0473,  0.0533,  0.3781,  0.7061,
          0.6122, -0.6028, -0.4614,  0.4978,  1.3796, -0.5250, -1.1612, -0.0190,
         -0.5980, -0.6977,  0.5652,  0.0690, -1.0179, -0.1406, -1.3673, -0.7538,
          0.6141,  2.5406,  0.0266, -0.7983,  0.8883, -0.0464, -1.4198, -0.3316,
          0.0727, -0.8053,  0.0274, -2.4423,  0.0520,  2.0904, -1.3377,  0.1818,
          1.2972, -1.2921, -0.1637,  1.3437, -0.0802,  1.6403,  0.2230, -0.5634,
         -1.0331, 

In [None]:
# Inspect the learned embedding vector for "great"
great_embedding = model.get_word_embedding("great")
print(f'embedding:{great_embedding}')

embedding:tensor([[ 4.1602e-01,  5.5634e-01,  2.2157e+00,  1.0961e+00, -2.8320e-01,
          1.7972e+00,  2.1557e+00, -6.9526e-01, -4.3523e-01, -6.7242e-01,
          1.0329e-02, -3.7228e-01,  1.8112e+00,  4.7911e-01, -1.1349e+00,
         -9.1864e-02,  3.6556e-01, -9.3028e-01,  1.8752e-02,  1.0644e+00,
          1.2995e+00,  3.5984e-01, -1.3454e+00, -3.0930e-02,  9.0014e-02,
         -9.8215e-01, -1.2702e+00,  2.6569e-01, -1.5493e+00,  1.7160e+00,
         -9.5832e-01,  3.8075e-01, -9.5799e-01, -7.3311e-01, -8.7787e-01,
          2.0520e+00, -2.1170e+00,  1.1215e-01, -6.0816e-01,  6.6146e-01,
         -4.6952e-01,  4.8175e-01,  1.6175e+00, -6.2025e-01, -4.1780e-02,
          8.8671e-01, -1.1293e+00, -5.7480e-01,  1.3292e-01,  1.2447e+00,
         -7.0102e-01, -6.8578e-01,  4.9787e-01,  4.3137e-01,  9.2317e-01,
          1.7474e-01,  1.7985e-01,  1.4353e-02, -5.7185e-01,  3.1124e-01,
          6.7811e-01,  7.5656e-01,  2.7621e-01,  6.2838e-01, -5.2297e-01,
         -1.8609e+00, -6.868