In [1]:
#
#
#
import tqdm
import collections
import more_itertools
import requests
import wandb
import torch




In [3]:


#
#
#
torch.manual_seed(42)


#
#
#
# r = requests.get("https://huggingface.co/datasets/ardMLX/text8/resolve/main/text8")
# with open("text8", "wb") as f: f.write(r.content)
with open('text8') as f: text8: str = f.read()



In [6]:

#
#
#
def preprocess(text: str) -> list[str]:
  text = text.lower()
  text = text.replace('.',  ' <PERIOD> ')
  text = text.replace(',',  ' <COMMA> ')
  text = text.replace('"',  ' <QUOTATION_MARK> ')
  text = text.replace(';',  ' <SEMICOLON> ')
  text = text.replace('!',  ' <EXCLAMATION_MARK> ')
  text = text.replace('?',  ' <QUESTION_MARK> ')
  text = text.replace('(',  ' <LEFT_PAREN> ')
  text = text.replace(')',  ' <RIGHT_PAREN> ')
  text = text.replace('--', ' <HYPHENS> ')
  text = text.replace('?',  ' <QUESTION_MARK> ')
  text = text.replace(':',  ' <COLON> ')
  words = text.split()
  stats = collections.Counter(words)
  words = [word for word in words if stats[word] > 5]
  return words


#
#
#
corpus: list[str] = preprocess(text8)
print(type(corpus)) # <class 'list'>
print(len(corpus))  # 16,680,599
print(corpus[:7])   # ['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse']



<class 'list'>
16680599
['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse']


In [7]:

#
#
#
def create_lookup_tables(words: list[str]) -> tuple[dict[str, int], dict[int, str]]:
  word_counts = collections.Counter(words)
  vocab = sorted(word_counts, key=lambda k: word_counts.get(k), reverse=True)
  int_to_vocab = {ii+1: word for ii, word in enumerate(vocab)}
  int_to_vocab[0] = '<PAD>'
  vocab_to_int = {word: ii for ii, word in int_to_vocab.items()}
  return vocab_to_int, int_to_vocab


#
#
#
words_to_ids, ids_to_words = create_lookup_tables(corpus)
tokens = [words_to_ids[word] for word in corpus]
print(type(tokens)) # <class 'list'>
print(len(tokens))  # 16,680,599
print(tokens[:7])   # [5234, 3081, 12, 6, 195, 2, 3134]


#
#
#
print(ids_to_words[5234])        # anarchism
print(words_to_ids['anarchism']) # 5234
print(words_to_ids['have'])      # 3081
print(len(words_to_ids))         # 63,642



<class 'list'>
16680599
[5234, 3081, 12, 6, 195, 2, 3134]
anarchism
5234
39
63642


In [19]:
words_to_ids['dog']

1902

In [40]:

#
#
#
class SkipGramFoo(torch.nn.Module):
  def __init__(self, voc, emb, _):
    super().__init__()
    self.emb = torch.nn.Embedding(num_embeddings=voc, embedding_dim=emb)
    self.ffw = torch.nn.Linear(in_features=emb, out_features=voc, bias=False)
    self.sig = torch.nn.Sigmoid()

  def forward(self, inpt, trgs, rand):
    emb = self.emb(inpt)
    ctx = self.ffw.weight[trgs]
    rnd = self.ffw.weight[rand]
    out = torch.bmm(ctx, emb.unsqueeze(-1)).squeeze()
    rnd = torch.bmm(rnd, emb.unsqueeze(-1)).squeeze()
    out = self.sig(out)
    rnd = self.sig(rnd)
    pst = -out.log().mean()
    ngt = -(1 - rnd + 10**(-3)).log().mean()
    return pst + ngt


#
#
#
embed_dim = 64
initial_lr = 0.01
arch = 'SkipGramFoo'
args = (len(words_to_ids), embed_dim, 2)
mFoo = SkipGramFoo(*args)
print('mFoo', sum(p.numel() for p in mFoo.parameters()))
opFoo = torch.optim.Adam(mFoo.parameters(), lr=initial_lr)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



mFoo 8146176


In [41]:

#
#
#
epochs = 1
windows = list(more_itertools.windowed(tokens, 3))
inputs = [w[1] for w in windows]
targets = [[w[0], w[2]] for w in windows]
input_tensor = torch.LongTensor(inputs)
target_tensor = torch.LongTensor(targets)
dataset = torch.utils.data.TensorDataset(input_tensor, target_tensor)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1024, shuffle=True)


#
#
#
wandb.init(project='mlx6-word2vec'
           ,config={
        "learning_rate": initial_lr,
        "architecture": arch,
        "dataset": "text8",
        "epochs": epochs,
    }
           , name='mFoo',)

mFoo.to(device)
for epoch in range(epochs):
  prgs = tqdm.tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=False)
  for inpt, trgs in prgs:
    inpt, trgs = inpt.to(device), trgs.to(device)
    rand = torch.randint(0, len(words_to_ids), (inpt.size(0), 2)).to(device)
    opFoo.zero_grad()
    loss = mFoo(inpt, trgs, rand)
    loss.backward()
    opFoo.step()
    wandb.log({'loss': loss.item()})



                                                               

In [37]:
# check embedding before training
word_of_interest = 'cat'

# select word of interest
word_id = words_to_ids[word_of_interest]
start_embed = torch.nn.Embedding(len(words_to_ids), embedding_dim=embed_dim)
word_embed = start_embed.weight.data[word_id]  # shape: [64]

# Get all embeddings (vocab_size x 64)
all_embeds = start_embed.weight.data  # shape: [vocab_size, 64]

# Compute dot products (matrix multiplication)
before_dot_products = torch.matmul(all_embeds, word_embed)  # shape: [vocab_size]

# Get top 10 most similar words
top_k = 10
values, indices = torch.topk(before_dot_products, k=top_k)

# Convert indices to words
before_similar_words = [ids_to_words[idx.item()] for idx in indices]

    
# check embedding after training
final_embeddings = mFoo.emb.weight.data  # Shape: [vocab_size, embedding_dim]
final_word_embed = final_embeddings[word_id]  # shape: [64]

# Compute dot products (matrix multiplication)
after_dot_products = torch.matmul(final_embeddings, final_word_embed)  # shape: [vocab_size]

# Get top 10 most similar words
values, indices = torch.topk(after_dot_products, k=top_k)

# Convert indices to words
after_similar_words = [ids_to_words[idx.item()] for idx in indices]


print("Top similar words to '"+word_of_interest+"' before training:")
for word, score in zip(before_similar_words, values):
    print(f"{word}: {score:.3f}")
print('')
print("Top similar words to '"+word_of_interest+"' after training:")
for word, score in zip(after_similar_words, values):
    print(f"{word}: {score:.3f}")

Top similar words to 'cat' before training:
cat: 122.651
tricolour: 43.696
offends: 40.661
andromache: 40.570
apparitions: 39.769
conquered: 39.643
rota: 39.596
zweites: 39.500
phocidae: 38.865
konkan: 38.744

Top similar words to 'cat' after training:
cat: 122.651
gassan: 43.696
connor: 40.661
alamos: 40.570
seabird: 39.769
happen: 39.643
morden: 39.596
medicis: 39.500
rivaled: 38.865
utter: 38.744


In [38]:

#
#
#
print('Saving...')
torch.save(mFoo.state_dict(), './weights.pt')
print('Uploading...')
artifact = wandb.Artifact('model-weights', type='model')
artifact.add_file('./weights.pt')
wandb.log_artifact(artifact)
print('Done!')
wandb.finish()


Saving...
Uploading...
Done!


0,1
loss,▄▃▃█▄▃▃▂▃▃▂▂▃▂▁▁▂▂▂▂▁▁▂▂▂▁▁▂▃▂▂▂▁▃▂▁▃▂▂▂

0,1
loss,0.3621
