In [2]:
!pip install numpy==1.23.5
!pip install --upgrade gensim


Collecting gensim
  Using cached gensim-4.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.1 kB)
Using cached gensim-4.3.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (26.7 MB)
Installing collected packages: gensim
Successfully installed gensim-4.3.3


In [1]:
import numpy as np
import gensim.downloader as api
from gensim.models import Word2Vec

glove_path = api.load("glove-wiki-gigaword-100", return_path=True)
sentences  = api.load("text8")




In [2]:
def train_w2v(sentences, sg, vector_size=100, window=5, min_count=5, epochs=5, workers=4):

    model = Word2Vec(vector_size=vector_size,
                     window=window,
                     min_count=min_count,
                     sg=sg,               # 1 = Skip‑gram || 0 = CBOW
                     workers=workers)

    model.build_vocab(sentences)

    model.wv.vectors_lockf = np.ones(len(model.wv), dtype=np.float32)

    model.wv.intersect_word2vec_format(glove_path, binary=0)


    model.train(sentences, total_examples=model.corpus_count, epochs=epochs)

    return model

In [5]:
def train_skipgram(sentences, **kw):
    return train_w2v(sentences, sg=1, **kw)


def train_cbow(sentences, **kw):
    return train_w2v(sentences, sg=0, **kw)

In [6]:
print("Training Skip‑gram …")
skipgram_model = train_skipgram(sentences)

print("Training CBOW …")
cbow_model = train_cbow(sentences)

Training Skip‑gram …
Training CBOW …


In [7]:
pred = "computer"
print("Skip‑gram neighbours of ", pred)
print(skipgram_model.wv.most_similar(pred, topn=5))

Skip‑gram neighbours of  computer
[('computers', 0.8751984238624573), ('software', 0.8373122215270996), ('technology', 0.7642159461975098), ('pc', 0.7366448640823364), ('hardware', 0.7290390729904175)]


In [31]:
pred = "scorpion"
print("Skip‑gram neighbours of ", pred)
print(skipgram_model.wv.most_similar(pred, topn=5))

Skip‑gram neighbours of  scorpion
[('spider', 0.6373819708824158), ('snake', 0.6178805828094482), ('venom', 0.5878459215164185), ('dragonfly', 0.5768774151802063), ('tortoise', 0.5757177472114563)]


In [34]:
pred = "dam"
print("Skip‑gram neighbours of ", pred)
print(skipgram_model.wv.most_similar(pred, topn=5))

Skip‑gram neighbours of  dam
[('reservoir', 0.8164873719215393), ('dams', 0.7653598189353943), ('hydroelectric', 0.7307214736938477), ('gorges', 0.7190767526626587), ('embankment', 0.6634266376495361)]


In [25]:
ctx = ["football", "is", "entertaining", "sport"]
print("CBOW neighbours of", ctx)
print(cbow_model.wv.most_similar(positive=ctx, topn=5))

CBOW neighbours of ['football', 'is', 'entertaining', 'sport']
[('sports', 0.7905215620994568), ('well', 0.7405361533164978), ('soccer', 0.7400915026664734), ('play', 0.7296680808067322), ('good', 0.7275350689888)]


In [27]:
vecs = [cbow_model.wv.get_vector(w, norm=True) for w in ctx if w in cbow_model.wv]
query = np.sum(vecs, axis=0)
query /= np.linalg.norm(query)
raw = cbow_model.wv.similar_by_vector(query, topn=len(ctx) + 5)
filtered = [(w, s) for w, s in raw if w not in ctx][:5]
print("via NP manually")
print(filtered)

via NP manually
[('sports', 0.7905215620994568), ('well', 0.7405362129211426), ('soccer', 0.7400915026664734), ('play', 0.7296680808067322), ('good', 0.7275350689888)]


In [29]:
ctx = [ "magic", "magician", "play", "tricks"]
print("CBOW neighbours of", ctx)
print(cbow_model.wv.most_similar(positive=ctx, topn=5))

CBOW neighbours of ['magic', 'magician', 'play', 'tricks']
[('trick', 0.7602158188819885), ('magical', 0.6949593424797058), ('playing', 0.670497715473175), ('luck', 0.6365276575088501), ('game', 0.6365230679512024)]


In [30]:
from numpy.linalg import norm

vecs = [cbow_model.wv.get_vector(w, norm=True) for w in ctx if w in cbow_model.wv]
query = np.sum(vecs, axis=0)
query /= np.linalg.norm(query)
raw = cbow_model.wv.similar_by_vector(query, topn=len(ctx) + 5)
filtered = [(w, s) for w, s in raw if w not in ctx][:5]
print("via NP manually")
print(filtered)

via NP manually
[('trick', 0.7602158188819885), ('magical', 0.6949593424797058), ('playing', 0.670497715473175), ('luck', 0.6365276575088501), ('game', 0.6365230679512024)]
