## 3.4 Теоретические вопросы: Дистрибутивная семантика

### Импорт необходимых библиотек

In [None]:
import numpy as np

### [Напишите функцию, которая генерирует обучающие примеры из текста (word2vec).](https://stepik.org/lesson/261476/step/2?unit=242225)

In [72]:
def generate_w2v_sgns_samples(text, window_size, vocab_size, ns_rate):
    """
    Generate skip-gram negative-sampling (SGNS) training samples for word2vec.

    text - list of integer numbers - ids of tokens in text
    window_size - odd integer - width of window
    vocab_size - positive integer - number of tokens in vocabulary
    ns_rate - positive integer - number of negative tokens to sample per one positive sample

    returns list of training samples [CenterWord, CtxWord, Label]; every
    positive sample (Label = 1) is immediately followed by its ns_rate
    negative samples (Label = 0), whose context ids are drawn uniformly
    from the whole vocabulary [0, vocab_size)
    """

    radius = window_size // 2
    text_len = len(text)
    result = []

    for center_pos, center_word in enumerate(text):
        # Clip the window to the text boundaries instead of materialising a
        # len(text) x len(text) mask just to find the neighbours.
        window_start = max(0, center_pos - radius)
        window_stop = min(text_len, center_pos + radius + 1)

        for ctx_pos in range(window_start, window_stop):
            if ctx_pos == center_pos:
                # A word is not a context of itself.
                continue

            result.append([center_word, text[ctx_pos], 1])
            for _ in range(ns_rate):
                # The upper bound of randint is exclusive, so passing
                # vocab_size makes the last token id reachable as well.
                result.append([center_word, np.random.randint(0, vocab_size), 0])

    return result

In [73]:
# Demo run: a short sequence of token ids over a six-token vocabulary,
# a window of three tokens and one negative sample per positive pair.
test_text = [1, 0, 1, 0, 0, 5, 0, 3, 5, 5, 3, 0, 5, 0, 5, 2, 0, 1, 3]
test_window_size, test_vocab_size, test_ns_rate = 3, 6, 1
test_result = generate_w2v_sgns_samples(
    text=test_text,
    window_size=test_window_size,
    vocab_size=test_vocab_size,
    ns_rate=test_ns_rate,
)

In [76]:
test_result

[[1, 0, 1],
 [1, 0, 0],
 [0, 1, 1],
 [0, 0, 0],
 [0, 1, 1],
 [0, 4, 0],
 [1, 0, 1],
 [1, 2, 0],
 [1, 0, 1],
 [1, 0, 0],
 [0, 1, 1],
 [0, 0, 0],
 [0, 0, 1],
 [0, 3, 0],
 [0, 0, 1],
 [0, 3, 0],
 [0, 5, 1],
 [0, 1, 0],
 [5, 0, 1],
 [5, 0, 0],
 [5, 0, 1],
 [5, 2, 0],
 [0, 5, 1],
 [0, 2, 0],
 [0, 3, 1],
 [0, 1, 0],
 [3, 0, 1],
 [3, 4, 0],
 [3, 5, 1],
 [3, 2, 0],
 [5, 3, 1],
 [5, 2, 0],
 [5, 5, 1],
 [5, 2, 0],
 [5, 5, 1],
 [5, 0, 0],
 [5, 3, 1],
 [5, 0, 0],
 [3, 5, 1],
 [3, 4, 0],
 [3, 0, 1],
 [3, 3, 0],
 [0, 3, 1],
 [0, 2, 0],
 [0, 5, 1],
 [0, 4, 0],
 [5, 0, 1],
 [5, 2, 0],
 [5, 0, 1],
 [5, 0, 0],
 [0, 5, 1],
 [0, 3, 0],
 [0, 5, 1],
 [0, 1, 0],
 [5, 0, 1],
 [5, 1, 0],
 [5, 2, 1],
 [5, 1, 0],
 [2, 5, 1],
 [2, 1, 0],
 [2, 0, 1],
 [2, 0, 0],
 [0, 2, 1],
 [0, 1, 0],
 [0, 1, 1],
 [0, 3, 0],
 [1, 0, 1],
 [1, 2, 0],
 [1, 3, 1],
 [1, 0, 0],
 [3, 1, 1],
 [3, 2, 0]]

### [Напишите функцию, обновляющую веса модели при получении одного обучающего примера (word2vec).](https://stepik.org/lesson/261476/step/3?unit=242225)

In [130]:
def update_w2v_weights(
    center_embeddings,
    context_embeddings,
    center_word,
    context_word,
    label,
    learning_rate,
):
    """
    Perform one SGD step of the word2vec (SGNS) model on a single sample.

    center_embeddings - VocabSize x EmbSize
    context_embeddings - VocabSize x EmbSize
    center_word - int - identifier of center word
    context_word - int - identifier of context word
    label - 1 if context_word is real, 0 if it is negative
    learning_rate - float > 0 - size of gradient step

    returns the updated (center vector, context vector) pair. If the tables
    are NumPy arrays, the selected rows are views and are modified in place;
    if they are plain lists, new arrays are returned and the tables stay
    untouched.
    """

    center_vec = center_embeddings[center_word]
    context_vec = context_embeddings[context_word]

    # Predicted probability that the pair is a real (positive) one.
    score = np.dot(center_vec, context_vec)
    predicted = 1 / (1 + np.exp(-score))

    # Shared scalar factor of both gradients; wrapping it in an array lets
    # the multiplication broadcast even when the vectors are Python lists.
    error = np.array(predicted - label)

    # Both deltas are computed before either vector is changed, so each
    # gradient uses the pre-update value of the other vector.
    center_delta = learning_rate * (context_vec * error)
    context_delta = learning_rate * (center_vec * error)

    center_vec -= center_delta
    context_vec -= context_delta

    return center_vec, context_vec

In [131]:
# Sample input for update_w2v_weights: two embedding tables with eight rows
# of three components each, given as plain nested lists.
center_embeddings = [
    [0.3449417709491044, 0.6762047256081501, 0.9583446027893963],
    [0.6247126159157468, 0.22038323197740317, 0.29717611444948355],
    [0.9836099232994968, 0.3847689688960674, 0.033312247867206435],
    [0.4217704869846559, 0.0023859008971685025, 0.009686915033163657],
    [0.6933070658521228, 0.9705089533296152, 0.9189360293193337],
    [0.024858486425111903, 0.11331113152689753, 0.6492144300167894],
    [0.7861289466352543, 0.227319130535791, 0.8165251907260063],
    [0.7672181161105678, 0.04865001026002924, 0.07514404284170773],
]
context_embeddings = [
    [0.4628817426583818, 0.7747296319956671, 0.1374808935513827],
    [0.17026823169513283, 0.4094733988461122, 0.3175531656197459],
    [0.2910876746161247, 0.6340566555548147, 0.23158010794029804],
    [0.8449042648180852, 0.4796593509107806, 0.11278090182290745],
    [0.049097778744511156, 0.6254116250148337, 0.13038703647472905],
    [0.882545488649187, 0.6223076699449618, 0.1633041302523962],
    [0.6704032810194875, 0.941803340812521, 0.7358646489592193],
    [0.9875878745059805, 0.17935677165390562, 0.6798846454394736],
]
# The sample pairs center row 2 with context row 5 as a negative example
# (label 0).
center_word = 2
context_word = 5
label = 0
learning_rate = 0.342405260598321

In [132]:
# Run a single update step on the sample input; the cell displays the
# returned (center vector, context vector) pair.
update_w2v_weights(
    center_embeddings=center_embeddings,
    context_embeddings=context_embeddings,
    center_word=center_word,
    context_word=context_word,
    label=label,
    learning_rate=learning_rate,
)

(array([ 0.75615844,  0.22438653, -0.00877484]),
 array([0.62904747, 0.5231442 , 0.15471883]))

### [Напишите функцию, которая генерирует обучающие примеры из текста (fasttext).](https://stepik.org/lesson/261476/step/4?unit=242225)

In [133]:
def generate_ft_sgns_samples(text, window_size, vocab_size, ns_rate, token2subwords):
    """
    Generate skip-gram negative-sampling (SGNS) training samples for fastText.

    text - list of integer numbers - ids of tokens in text
    window_size - odd integer - width of window
    vocab_size - positive integer - number of tokens in vocabulary
    ns_rate - positive integer - number of negative tokens to sample per one positive sample
    token2subwords - list of lists of int - i-th sublist contains list of identifiers of n-grams for token #i (list of subword units)

    returns list of training samples [CenterSubwords, CtxWord, Label];
    CenterSubwords is the de-duplicated list made of the center token id and
    its subword ids. Every positive sample (Label = 1) is immediately
    followed by its ns_rate negative samples (Label = 0), whose context ids
    are drawn uniformly from the whole vocabulary [0, vocab_size)
    """

    radius = window_size // 2
    text_len = len(text)
    result = []

    for center_pos, center_word in enumerate(text):
        # Clip the window to the text boundaries instead of materialising a
        # len(text) x len(text) mask just to find the neighbours.
        window_start = max(0, center_pos - radius)
        window_stop = min(text_len, center_pos + radius + 1)

        for ctx_pos in range(window_start, window_stop):
            if ctx_pos == center_pos:
                # A word is not a context of itself.
                continue

            # The center word is represented by itself plus its n-grams;
            # the set removes ids that occur more than once.
            center_subwords = list(set([center_word] + token2subwords[center_word]))
            result.append([center_subwords, text[ctx_pos], 1])
            for _ in range(ns_rate):
                # The upper bound of randint is exclusive, so passing
                # vocab_size makes the last token id reachable as well.
                result.append(
                    [center_subwords, np.random.randint(0, vocab_size), 0]
                )

    return result

### [Напишите функцию, обновляющую веса модели при получении одного обучающего примера (fasttext).](https://stepik.org/lesson/261476/step/5?unit=242225)

In [215]:
def update_ft_weights(
    center_embeddings,
    context_embeddings,
    center_subwords,
    context_word,
    label,
    learning_rate,
):
    """
    Perform one SGD step of the fastText (SGNS) model on a single sample.

    center_embeddings - VocabSize x EmbSize
    context_embeddings - VocabSize x EmbSize
    center_subwords - list of ints - list of identifiers of n-grams contained in center word
    context_word - int - identifier of context word
    label - 1 if context_word is real, 0 if it is negative
    learning_rate - float > 0 - size of gradient step

    returns (center_embeddings, context_embeddings) as new NumPy arrays with
    the step applied; the arguments themselves are copied, not modified
    """

    # np.array copies its input, so all updates below go to the copies.
    center_table = np.array(center_embeddings)
    context_table = np.array(context_embeddings)
    n_subwords = len(center_subwords)

    # The center word vector is the mean of its subword vectors.
    center_vec = sum([center_table[idx] for idx in center_subwords]) / n_subwords
    context_vec = context_table[context_word]

    # Predicted probability that the pair is a real (positive) one.
    score = np.dot(center_vec, context_vec)
    predicted = 1 / (1 + np.exp(-score))
    error = np.array(predicted - label)

    # Each subword contributes 1 / n_subwords to the mean, hence the extra
    # division in the step applied to every subword vector.
    subword_delta = learning_rate * (context_vec * error) / n_subwords
    context_delta = learning_rate * (center_vec * error)

    for idx in center_subwords:
        center_table[idx] -= subword_delta
    context_table[context_word] -= context_delta

    return center_table, context_table

In [216]:
update_ft_weights(
    [
        [
            0.07217140995735816,
            0.9807495045952024,
            0.5888650678318127,
            0.9419020475323008,
            0.9698687137771355,
        ],
        [
            0.17481764801167854,
            0.9598681333667267,
            0.8615416075076997,
            0.6649845254089604,
            0.14272822189820067,
        ],
        [
            0.695257160390079,
            0.6252124583357915,
            0.788572884360212,
            0.5407620598434707,
            0.4742760619803522,
        ],
        [
            0.3720755825170682,
            0.8734430653555122,
            0.29388553936147677,
            0.7833976055802006,
            0.11647446813597206,
        ],
        [
            0.4793503066165381,
            0.7731679392102295,
            0.6466062364447424,
            0.5834632727525674,
            0.16975097768580916,
        ],
        [
            0.46855676928071344,
            0.7440440871653314,
            0.5968916205486556,
            0.6949993371605877,
            0.9995564750677164,
        ],
        [
            0.3995517204225809,
            0.30217048674177027,
            0.6934836340605662,
            0.5025452046745376,
            0.43990420866402447,
        ],
        [
            0.6233285824044058,
            0.7510765715859197,
            0.8764982899024905,
            0.42892241183749247,
            0.9241569354174014,
        ],
        [
            0.21063022873083803,
            0.979366603599722,
            0.07879437255385402,
            0.7103116511451802,
            0.298121842692622,
        ],
        [
            0.7991181799927396,
            0.8700912396205017,
            0.4936455488806514,
            0.9306352063022928,
            0.671689987782089,
        ],
        [
            0.11245515636577097,
            0.2591385008756272,
            0.38393130144123977,
            0.5927928993875077,
            0.3343301767582757,
        ],
        [
            0.027340724019638274,
            0.15461071231349877,
            0.7955192467457007,
            0.050624838697975516,
            0.26136570172628426,
        ],
        [
            0.7825083895933859,
            0.9046538942978853,
            0.4559636175207443,
            0.733829258685726,
            0.022174763292638677,
        ],
        [
            0.6968176063951074,
            0.47974647096747125,
            0.8885207189970179,
            0.016167994434510558,
            0.13260182909882334,
        ],
        [
            0.5947903955259933,
            0.07459974351651177,
            0.11391699485528617,
            0.823474357110585,
            0.4918622459339238,
        ],
        [
            0.6272760016913231,
            0.2711994820963495,
            0.24338892914238242,
            0.7731707300677505,
            0.03720128542002399,
        ],
        [
            0.8640858092433228,
            0.027663971153230382,
            0.9271422334467209,
            0.37457369227183035,
            0.17413436429736662,
        ],
        [
            0.4878584763813121,
            0.5022845803948351,
            0.13899660663745628,
            0.8353408935052742,
            0.48314336609381436,
        ],
        [
            0.8197910105251979,
            0.5371430936015362,
            0.12965724315376936,
            0.06244349080403733,
            0.9558816248633216,
        ],
        [
            0.5929477505385994,
            0.36687167726065173,
            0.42925321480480627,
            0.8435274356179648,
            0.8550018469714032,
        ],
        [
            0.45785273815309,
            0.008764229829187009,
            0.6840407156586629,
            0.04831125277736026,
            0.14609911971743395,
        ],
        [
            0.1579479219010974,
            0.1298470924838635,
            0.8283362978065627,
            0.9140741421274726,
            0.7516395431217443,
        ],
        [
            0.01139316661353773,
            0.6980229640742956,
            0.45528806869472405,
            0.7653990849713008,
            0.24848012670857944,
        ],
        [
            0.8750941097872984,
            0.6964598870452183,
            0.6675389863133752,
            0.391939718013135,
            0.30592620271209714,
        ],
        [
            0.024161748164975072,
            0.6512328549928654,
            0.27784751504029503,
            0.32588414662648524,
            0.4073676483413957,
        ],
        [
            0.7372935688667617,
            0.9743689028772393,
            0.26179932035274445,
            0.3556999822154028,
            0.8234406534181563,
        ],
        [
            0.9358431512408416,
            0.0030942521035778325,
            0.7052198210371732,
            0.3494249594704901,
            0.06494462197366668,
        ],
        [
            0.027642224051125597,
            0.45820907093457997,
            0.6172763215932299,
            0.03520578036716404,
            0.05004091043245007,
        ],
    ],
    [
        [
            0.3619192935809462,
            0.7910582560833153,
            0.173840770588212,
            0.8486217599360419,
            0.09895998679198104,
        ],
        [
            0.9524670374363299,
            0.577316446205222,
            0.3348594666828074,
            0.7987547183235284,
            0.710457681490417,
        ],
        [
            0.8400820704952479,
            0.9414962586451427,
            0.08399082278691339,
            0.425927381574433,
            0.6304514720560764,
        ],
        [
            0.5331686510681622,
            0.2751366715811131,
            0.8329999135745643,
            0.2770290564458684,
            0.020564166091874392,
        ],
        [
            0.9852792048968001,
            0.922320208232837,
            0.7297936992308128,
            0.20212997935663524,
            0.5277458149323955,
        ],
        [
            0.43383566311415755,
            0.14151987203148808,
            0.3267585826852797,
            0.8796734627573763,
            0.14253685112772174,
        ],
        [
            0.24559727482999572,
            0.3015598034026842,
            0.12351719983998721,
            0.6141130319406622,
            0.9210871618079258,
        ],
        [
            0.21915908704207665,
            0.9809645232509783,
            0.8685879466971278,
            0.9956335594634693,
            0.0441562419906687,
        ],
        [
            0.24988758739587902,
            0.42298807118368675,
            0.01922872769211703,
            0.02806386746602596,
            0.2821901214584819,
        ],
        [
            0.43997555452635384,
            0.5078839449569567,
            0.812607950040521,
            0.9998014106280365,
            0.1559607489614684,
        ],
        [
            0.9092151190046189,
            0.5930002929595868,
            0.315159378929991,
            0.4052299042409616,
            0.984475831988958,
        ],
        [
            0.7836990450026143,
            0.002466529016497798,
            0.8465916260137056,
            0.7227126698344118,
            0.5087557482398855,
        ],
        [
            0.4125921074144525,
            0.5582795115000383,
            0.889307828978137,
            0.928416977596577,
            0.8437462138575066,
        ],
        [
            0.11810981794872477,
            0.07787452990697508,
            0.3907338451314212,
            0.6841828899516664,
            0.4547615738832046,
        ],
        [
            0.4977766315279062,
            0.09878866849137813,
            0.0622140049250518,
            0.9008881823827194,
            0.3694055807903669,
        ],
        [
            0.12415427540834822,
            0.01064247175537103,
            0.1439469061372417,
            0.43996173718103593,
            0.3846553735294024,
        ],
        [
            0.36544315427420426,
            0.6651402425072226,
            0.3837201693785094,
            0.54713466624535,
            0.6925194086063208,
        ],
        [
            0.8217730539154436,
            0.7380601103419114,
            0.4790971996703556,
            0.935248458815274,
            0.6385239169547122,
        ],
        [
            0.4884363834477089,
            0.783319748626155,
            0.018212966919229467,
            0.03662832627793777,
            0.03532160993715294,
        ],
        [
            0.6820505211290306,
            0.25769913167047753,
            0.9677388106523852,
            0.4471332422618759,
            0.7731319006564568,
        ],
        [
            0.3695513424667971,
            0.5118113495291988,
            0.1721439269100805,
            0.09451631327113852,
            0.8369170475041434,
        ],
        [
            0.7918542552021289,
            0.0245240901264403,
            0.6658133706965796,
            0.9740885323982209,
            0.02660284500887522,
        ],
        [
            0.5604137104962275,
            0.5643917632639455,
            0.6756476068355826,
            0.9466913679034125,
            0.21062462975598062,
        ],
        [
            0.7306868573812846,
            0.7573083135261555,
            0.9450278665003865,
            0.9649869335038909,
            0.1262321882978371,
        ],
        [
            0.6830284536315845,
            0.7383035166437748,
            0.7985226892860073,
            0.005247820534787007,
            0.6886083391552933,
        ],
        [
            0.6905561126225058,
            0.3220803445510755,
            0.8885006766287556,
            0.32709316933290455,
            0.9126547743770385,
        ],
        [
            0.26866358146648694,
            0.9355232286537734,
            0.5254946965960933,
            0.6487428023364232,
            0.9405298594379049,
        ],
        [
            0.33881123962516546,
            0.6820622877451537,
            0.3053828831926755,
            0.9229486901650673,
            0.5450270097149575,
        ],
    ],
    [3],
    1,
    1,
    0.8562235244377375,
)

(array([[0.07217141, 0.9807495 , 0.58886507, 0.94190205, 0.96986871],
        [0.17481765, 0.95986813, 0.86154161, 0.66498453, 0.14272822],
        [0.69525716, 0.62521246, 0.78857288, 0.54076206, 0.47427606],
        [0.50175945, 0.95204802, 0.33947858, 0.89215266, 0.21320737],
        [0.47935031, 0.77316794, 0.64660624, 0.58346327, 0.16975098],
        [0.46855677, 0.74404409, 0.59689162, 0.69499934, 0.99955648],
        [0.39955172, 0.30217049, 0.69348363, 0.5025452 , 0.43990421],
        [0.62332858, 0.75107657, 0.87649829, 0.42892241, 0.92415694],
        [0.21063023, 0.9793666 , 0.07879437, 0.71031165, 0.29812184],
        [0.79911818, 0.87009124, 0.49364555, 0.93063521, 0.67168999],
        [0.11245516, 0.2591385 , 0.3839313 , 0.5927929 , 0.33433018],
        [0.02734072, 0.15461071, 0.79551925, 0.05062484, 0.2613657 ],
        [0.78250839, 0.90465389, 0.45596362, 0.73382926, 0.02217476],
        [0.69681761, 0.47974647, 0.88852072, 0.01616799, 0.13260183],
        [0.5947904 ,

### [Построить матрицу совместной встречаемости токенов (GloVe).](https://stepik.org/lesson/261476/step/6?unit=242225)

In [242]:
import scipy.sparse

In [252]:
def generate_coocurrence_matrix(texts, vocab_size):
    """
    Build the document-level co-occurrence matrix of tokens.

    texts - list of lists of ints - i-th sublist contains identifiers of tokens in i-th document
    vocab_size - int - size of vocabulary
    returns scipy.sparse.dok_matrix of shape vocab_size x vocab_size (float
    values): entry [i, j] with i != j is the number of documents that contain
    both token i and token j; the diagonal is zero. Token ids outside
    [0, vocab_size) are ignored.
    """

    # Count pairs document by document: only tokens that actually occur are
    # visited, instead of scanning every document for each of the
    # vocab_size ** 2 possible pairs.
    pair_counts = {}
    for text in texts:
        # A document counts once per pair regardless of repetitions.
        present = {token for token in text if 0 <= token < vocab_size}
        for first in present:
            for second in present:
                if first != second:
                    key = (first, second)
                    pair_counts[key] = pair_counts.get(key, 0) + 1

    # Fill the sparse matrix directly; no dense vocab_size x vocab_size
    # array is allocated. float64 matches the dtype of the previous result.
    matrix = scipy.sparse.dok_matrix((vocab_size, vocab_size), dtype=np.float64)
    for (first, second), count in pair_counts.items():
        matrix[first, second] = count

    return matrix

In [253]:
# Three tiny documents over a three-token vocabulary; the cell displays the
# dense form of the resulting sparse matrix.
generate_coocurrence_matrix(
    texts=[[0, 2, 2, 2, 0, 0], [1, 1, 2, 1, 1], [2, 2, 1, 1]],
    vocab_size=3,
).toarray()

array([[0., 0., 1.],
       [0., 0., 2.],
       [1., 2., 0.]])

### [Напишите функцию, которая обновляет параметры модели градиентным спуском (GloVe).](https://stepik.org/lesson/261476/step/7?unit=242225)

In [623]:
def update_glove_weights(x, w, d, alpha, max_x, learning_rate):
    """
    Perform one full-batch gradient descent step of the GloVe model.

    x - square integer matrix VocabSize x VocabSize - co-occurrence matrix
    w - VocabSize x EmbSize - first word vectors
    d - VocabSize x EmbSize - second word vectors
    alpha - float - power in weight smoothing function f
    max_x - int - maximum co-occurrence count in weight smoothing function f
    learning_rate - positive float - size of gradient step

    returns (w, d) - the same arrays that were passed in, updated in place
    """

    # Weight smoothing function f: (x / max_x) ** alpha up to max_x, 1 above.
    smoothed = (x / max_x) ** alpha
    pair_weights = np.where(x <= max_x, smoothed, 1)

    # Weighted residual between the dot products and log(1 + x).
    residual = w @ d.T - np.log(1 + x)
    doubled_weighted_residual = 2 * (pair_weights * residual)

    # Both steps are computed from the current w and d before either of
    # them is modified.
    w_delta = learning_rate * (doubled_weighted_residual @ d)
    d_delta = learning_rate * (doubled_weighted_residual.T @ w)

    w -= w_delta
    d -= d_delta

    return w, d

In [624]:
x = np.array(
    [
        [
            72,
            67,
            24,
            81,
            52,
            43,
            49,
            12,
            84,
            77,
            22,
            66,
            66,
            0,
            59,
            4,
            71,
            78,
            37,
            69,
            39,
            63,
            68,
            36,
            97,
            17,
        ],
        [
            72,
            38,
            34,
            43,
            11,
            46,
            91,
            96,
            43,
            4,
            80,
            77,
            19,
            18,
            39,
            8,
            43,
            58,
            59,
            43,
            40,
            55,
            14,
            96,
            90,
            43,
        ],
        [
            83,
            82,
            22,
            93,
            5,
            17,
            68,
            30,
            2,
            67,
            7,
            8,
            34,
            2,
            88,
            66,
            31,
            52,
            96,
            13,
            9,
            83,
            3,
            9,
            91,
            15,
        ],
        [
            73,
            25,
            40,
            82,
            42,
            30,
            79,
            77,
            15,
            76,
            65,
            6,
            7,
            44,
            98,
            88,
            65,
            74,
            33,
            48,
            61,
            54,
            64,
            28,
            49,
            38,
        ],
        [
            47,
            41,
            2,
            54,
            5,
            13,
            36,
            0,
            97,
            15,
            80,
            90,
            38,
            27,
            24,
            31,
            32,
            20,
            77,
            20,
            8,
            11,
            24,
            19,
            77,
            23,
        ],
        [
            55,
            26,
            42,
            5,
            98,
            87,
            36,
            1,
            11,
            19,
            57,
            68,
            92,
            49,
            98,
            9,
            98,
            24,
            0,
            13,
            14,
            90,
            10,
            51,
            30,
            30,
        ],
        [
            98,
            12,
            2,
            66,
            27,
            12,
            12,
            60,
            46,
            71,
            89,
            82,
            75,
            49,
            5,
            77,
            52,
            96,
            29,
            32,
            51,
            71,
            45,
            16,
            74,
            12,
        ],
        [
            18,
            92,
            95,
            62,
            51,
            65,
            1,
            49,
            51,
            62,
            16,
            64,
            97,
            48,
            78,
            14,
            90,
            50,
            43,
            49,
            59,
            11,
            75,
            50,
            60,
            2,
        ],
        [
            47,
            45,
            88,
            78,
            93,
            30,
            79,
            20,
            69,
            68,
            6,
            76,
            41,
            3,
            57,
            98,
            62,
            6,
            65,
            53,
            7,
            9,
            76,
            96,
            19,
            88,
        ],
        [
            45,
            73,
            39,
            70,
            21,
            62,
            82,
            13,
            14,
            72,
            8,
            23,
            99,
            49,
            33,
            80,
            21,
            67,
            37,
            31,
            38,
            48,
            40,
            61,
            61,
            67,
        ],
        [
            61,
            86,
            91,
            61,
            13,
            88,
            79,
            56,
            78,
            87,
            91,
            94,
            37,
            14,
            15,
            44,
            91,
            3,
            6,
            23,
            15,
            85,
            18,
            58,
            11,
            4,
        ],
        [
            50,
            28,
            55,
            44,
            21,
            62,
            98,
            64,
            85,
            84,
            4,
            31,
            59,
            16,
            51,
            11,
            37,
            44,
            6,
            60,
            47,
            54,
            70,
            29,
            32,
            74,
        ],
        [
            39,
            6,
            17,
            54,
            15,
            71,
            24,
            94,
            5,
            16,
            15,
            74,
            43,
            98,
            75,
            10,
            79,
            78,
            99,
            47,
            99,
            4,
            22,
            90,
            12,
            19,
        ],
        [
            74,
            51,
            67,
            72,
            21,
            9,
            57,
            50,
            0,
            43,
            80,
            91,
            58,
            46,
            92,
            98,
            11,
            4,
            36,
            31,
            90,
            90,
            91,
            52,
            68,
            63,
        ],
        [
            95,
            76,
            24,
            52,
            3,
            71,
            19,
            75,
            34,
            92,
            83,
            15,
            77,
            12,
            96,
            58,
            63,
            68,
            75,
            9,
            28,
            44,
            30,
            94,
            67,
            49,
        ],
        [
            22,
            93,
            33,
            77,
            2,
            9,
            3,
            3,
            47,
            56,
            84,
            70,
            15,
            81,
            16,
            49,
            20,
            95,
            18,
            22,
            98,
            3,
            77,
            27,
            1,
            13,
        ],
        [
            45,
            63,
            34,
            0,
            75,
            45,
            30,
            23,
            7,
            7,
            80,
            62,
            34,
            11,
            41,
            16,
            45,
            6,
            11,
            21,
            18,
            55,
            7,
            24,
            18,
            70,
        ],
        [
            13,
            7,
            21,
            85,
            29,
            53,
            56,
            83,
            63,
            89,
            18,
            67,
            93,
            73,
            37,
            3,
            55,
            65,
            16,
            72,
            6,
            80,
            0,
            39,
            51,
            24,
        ],
        [
            72,
            23,
            9,
            56,
            60,
            88,
            69,
            6,
            8,
            92,
            3,
            44,
            29,
            5,
            58,
            58,
            55,
            24,
            48,
            57,
            28,
            69,
            64,
            72,
            58,
            98,
        ],
        [
            20,
            56,
            52,
            74,
            27,
            95,
            85,
            20,
            7,
            52,
            8,
            93,
            76,
            53,
            62,
            54,
            34,
            25,
            89,
            38,
            85,
            29,
            38,
            18,
            1,
            28,
        ],
        [
            30,
            97,
            74,
            11,
            36,
            92,
            55,
            74,
            34,
            29,
            12,
            40,
            61,
            69,
            54,
            72,
            14,
            64,
            73,
            75,
            75,
            4,
            37,
            47,
            17,
            29,
        ],
        [
            7,
            18,
            25,
            6,
            51,
            63,
            63,
            53,
            38,
            96,
            19,
            56,
            36,
            35,
            75,
            99,
            32,
            28,
            68,
            14,
            55,
            9,
            3,
            19,
            9,
            59,
        ],
        [
            0,
            98,
            57,
            98,
            13,
            25,
            55,
            56,
            58,
            37,
            30,
            90,
            51,
            71,
            10,
            36,
            58,
            94,
            32,
            80,
            95,
            44,
            40,
            82,
            99,
            6,
        ],
        [
            74,
            28,
            93,
            37,
            81,
            54,
            92,
            89,
            52,
            96,
            93,
            8,
            65,
            82,
            7,
            14,
            75,
            0,
            45,
            59,
            15,
            17,
            85,
            87,
            10,
            52,
        ],
        [
            10,
            74,
            13,
            23,
            56,
            25,
            66,
            59,
            86,
            39,
            47,
            72,
            92,
            28,
            23,
            75,
            23,
            18,
            5,
            20,
            36,
            52,
            42,
            56,
            20,
            7,
        ],
        [
            32,
            37,
            58,
            20,
            3,
            33,
            76,
            92,
            36,
            73,
            90,
            53,
            82,
            78,
            6,
            66,
            11,
            33,
            64,
            68,
            51,
            76,
            94,
            94,
            74,
            88,
        ],
    ]
)
w = np.array(
    [
        [
            0.7236403458959406,
            0.0956019387576047,
            0.0025299248050427714,
            0.8219024304497274,
            0.43253754513562515,
            0.8013795226500925,
        ],
        [
            0.1645225418615962,
            0.17254764305062675,
            0.915834884927677,
            0.15659274788174238,
            0.4408801726853846,
            0.6712507398638423,
        ],
        [
            0.7220314060070252,
            0.1109087497279424,
            0.8673890374761482,
            0.6019681601593759,
            0.21136092547712715,
            0.46410460250177055,
        ],
        [
            0.2051472970020488,
            0.7021578939163269,
            0.4920315519905448,
            0.8786530949689468,
            0.8406582658875078,
            0.7656322995670249,
        ],
        [
            0.5314722945128192,
            0.20582039242966288,
            0.6649783801689887,
            0.9122470167268962,
            0.06046820688028054,
            0.7640361944809368,
        ],
        [
            0.8531299103217095,
            0.8837919293919477,
            0.5584731093192602,
            0.5488851769744959,
            0.5426259488733682,
            0.8101919492091457,
        ],
        [
            0.014691936047236509,
            0.8299297933323541,
            0.04420642840864686,
            0.19514486051010316,
            0.5605834763387445,
            0.021425480951998255,
        ],
        [
            0.6251450063221531,
            0.916013278510962,
            0.9266733043623226,
            0.4314909906070713,
            0.5861250222822415,
            0.6933275681854775,
        ],
        [
            0.613436662402963,
            0.25971014970117345,
            0.8516017571376222,
            0.3946078968050868,
            0.5030576607821642,
            0.4947379037657953,
        ],
        [
            0.1831704150579163,
            0.7027400367924451,
            0.9687380255252486,
            0.05161874874595729,
            0.5662001008903554,
            0.1163342848387866,
        ],
        [
            0.7871817922619593,
            0.2744881375377821,
            0.47927673745333677,
            0.29916960674176196,
            0.36165825794500894,
            0.4473132902029585,
        ],
        [
            0.9460136327515588,
            0.9784510345542305,
            0.7292583652396575,
            0.9967710630754123,
            0.5222338378761943,
            0.15774056366446398,
        ],
        [
            0.5663154377790205,
            0.31992559317458047,
            0.895903341426127,
            0.10800834113175917,
            0.7025174488794499,
            0.09983260287294515,
        ],
        [
            0.2870859802344229,
            0.6124244361792336,
            0.03043370710190285,
            0.4177754705816856,
            0.41076530192454186,
            0.059229404317664214,
        ],
        [
            0.453421549409623,
            0.10006035499361832,
            0.4729640823042591,
            0.4187735846604017,
            0.19252902582118436,
            0.4571615927038022,
        ],
        [
            0.4717366823003555,
            0.311470963714212,
            0.7563429074261462,
            0.9450429903711869,
            0.23851560864324461,
            0.4264206092799121,
        ],
        [
            0.14209483434392234,
            0.9545183136517666,
            0.02853067102355411,
            0.8397788414889452,
            0.28747164060068653,
            0.5890799959267197,
        ],
        [
            0.7137144457967627,
            0.7108041311984524,
            0.3391605131543025,
            0.18466650700703768,
            0.07037926283668172,
            0.1691030355977058,
        ],
        [
            0.4181167385409663,
            0.5733773938988352,
            0.9308794863064511,
            0.955104551017489,
            0.7472618752255964,
            0.9106883383537705,
        ],
        [
            0.29209827546854006,
            0.7950653331872178,
            0.9314779081831699,
            0.2137419943082265,
            0.9590688802321072,
            0.21779623076769017,
        ],
        [
            0.6414528631722118,
            0.7772400748403205,
            0.7240597746441493,
            0.4846785371953165,
            0.20903895145878393,
            0.9928711008461597,
        ],
        [
            0.4987552039133927,
            0.966261456826001,
            0.6392910461562884,
            0.3891694028095307,
            0.14376415691424704,
            0.5654942409405452,
        ],
        [
            0.39062876410463865,
            0.4372793328535669,
            0.9066881332880398,
            0.928194141998039,
            0.26891611788606773,
            0.970014111003586,
        ],
        [
            0.05753018657343756,
            0.5987554892139141,
            0.6695393400712614,
            0.4342378657370556,
            0.5068004463455815,
            0.28913437767829675,
        ],
        [
            0.31284712702847906,
            0.6696586256781413,
            0.6349611781499843,
            0.11008282689008553,
            0.9000387199581723,
            0.5893732652223279,
        ],
        [
            0.38771861901614457,
            0.9275236976874062,
            0.1507893346167909,
            0.2649576462980838,
            0.8917999241041804,
            0.7060665522096253,
        ],
    ]
)
d = np.array(
    [
        [
            0.7146175764456325,
            0.31161087332596693,
            0.799868898982844,
            0.3303984762074823,
            0.15755367025489198,
            0.8822561515814714,
        ],
        [
            0.48324415449065805,
            0.32294633607735035,
            0.273076894762348,
            0.46575965932905583,
            0.35173647464295466,
            0.0698782343365999,
        ],
        [
            0.05951092454514029,
            0.9631544906381114,
            0.14919875559361273,
            0.9071033838543416,
            0.9235221236014998,
            0.15343960980130578,
        ],
        [
            0.37667471994346735,
            0.3832592710693109,
            0.1372971042292026,
            0.5063394603470396,
            0.3657347277059969,
            0.21520394748123772,
        ],
        [
            0.5589413502705171,
            0.9228726685280682,
            0.9028006349689756,
            0.7902921185261006,
            0.09337560160131464,
            0.8806823905125992,
        ],
        [
            0.19078196327854235,
            0.9862705503057667,
            0.00800242331367751,
            0.7036641885324555,
            0.7452071471082209,
            0.85314563397203,
        ],
        [
            0.42059808696733225,
            0.3678976649279547,
            0.15153142787888962,
            0.9831212856723789,
            0.7218055186807681,
            0.8943329971799489,
        ],
        [
            0.058672278269596756,
            0.6364681756816949,
            0.24610924719747507,
            0.8429515887557353,
            0.23639035927773622,
            0.9193123017124043,
        ],
        [
            0.3667295853360063,
            0.46010540148263646,
            0.818107188288508,
            0.027140526241385965,
            0.4420026102807323,
            0.3050634480740779,
        ],
        [
            0.9602073143407148,
            0.5408373879572825,
            0.4027285042008486,
            0.854769594232319,
            0.8977332204421882,
            0.7804511190784789,
        ],
        [
            0.9554030213710992,
            0.6286064807931032,
            0.7899715293283952,
            0.20778805629585584,
            0.34452317136784105,
            0.8373278109724016,
        ],
        [
            0.9511367017094053,
            0.8108673965379353,
            0.5917802839407773,
            0.08638924272725734,
            0.5614389008614823,
            0.10285577516634681,
        ],
        [
            0.15293610366355237,
            0.4726630546010566,
            0.7151593451811216,
            0.6787398883364194,
            0.05726564336395312,
            0.11175850750236216,
        ],
        [
            0.5284613368117816,
            0.2171952870224766,
            0.14730305464381344,
            0.16327154211081985,
            0.22473713798444594,
            0.8780618686814565,
        ],
        [
            0.6229749344457463,
            0.2450307938022901,
            0.856716114606889,
            0.5130970556519491,
            0.09638792050417233,
            0.5580480219996736,
        ],
        [
            0.02557257524793899,
            0.16614696997999867,
            0.9057474930205891,
            0.9639638373151679,
            0.8305505098688646,
            0.13212730388642224,
        ],
        [
            0.9945224728362285,
            0.601015567237635,
            0.627777689771871,
            0.062014306890884385,
            0.5482657713187832,
            0.050645865034282034,
        ],
        [
            0.222530564647055,
            0.16270913407631815,
            0.3463743065572499,
            0.3642479732760492,
            0.6787809827842912,
            0.6698646733332234,
        ],
        [
            0.0514618465561959,
            0.09146560753484756,
            0.5663782403169225,
            0.09809277695721963,
            0.7435268283749681,
            0.6941669527997202,
        ],
        [
            0.7220783710745816,
            0.242189075941737,
            0.19197963437514165,
            0.2789768605860322,
            0.1257100212184865,
            0.25803379668907667,
        ],
        [
            0.5231678066470311,
            0.4611093289035184,
            0.8420569136872692,
            0.9566490894261072,
            0.07691192438283945,
            0.37613366065780873,
        ],
        [
            0.010481514678729265,
            0.5145103851754453,
            0.9425491945781952,
            0.24440293940943314,
            0.2766636476384883,
            0.9944680222564074,
        ],
        [
            0.7081598239606982,
            0.3291415847107684,
            0.7986116830970068,
            0.32005951294163504,
            0.988878016430301,
            0.16702718654180948,
        ],
        [
            0.3281899603978248,
            0.8371583970360936,
            0.914781121298489,
            0.9898376984561366,
            0.2605393835200668,
            0.7046307961318979,
        ],
        [
            0.6669241697435273,
            0.7506837943872975,
            0.3223310011040579,
            0.8024412673323509,
            0.47139557621217376,
            0.34991596973647043,
        ],
        [
            0.6981065171211724,
            0.7907802796637423,
            0.05700463849852977,
            0.29301210116680565,
            0.3246921756526622,
            0.9147896908728982,
        ],
    ]
)

In [625]:
# Run one update on the fixture matrices and rebind `w` and `d` to the results.
# NOTE(review): the three scalar arguments are passed positionally; their meaning
# is defined by update_glove_weights, which is outside this cell — confirm there.
w, d = update_glove_weights(
    x,
    w,
    d,
    0.5878728651551414,
    97,
    0.8066056621137515,
)

In [626]:
w_a = [
    [
        37.894949622265095,
        35.32608453541813,
        39.58755103784916,
        33.56742861491748,
        31.889707572091282,
        33.00715086688418,
    ],
    [
        33.84381119427707,
        38.036945787753176,
        34.69008748288071,
        36.55796979484229,
        30.50706391340045,
        40.418661343576375,
    ],
    [
        22.292630735670524,
        22.26123383895233,
        27.281358895160945,
        29.762116352667036,
        26.093786102891865,
        28.873616599806525,
    ],
    [
        28.627244763428482,
        25.010240783610097,
        30.24910190951392,
        29.94251520200709,
        25.793350857840288,
        31.500830971429988,
    ],
    [
        24.86914949950388,
        21.18902354047288,
        25.17742133333728,
        16.734270195198985,
        21.32882188455089,
        20.133684294288226,
    ],
    [
        20.607435095197772,
        23.38444651793354,
        24.170604089788068,
        19.190985999049406,
        14.732626697936645,
        22.085786623565003,
    ],
    [
        41.83248643873107,
        39.23972477624278,
        48.25206831443954,
        39.18447489715208,
        36.223555643455576,
        43.838987581052415,
    ],
    [
        25.957662536809675,
        27.78414465308494,
        26.831852326280252,
        27.670481277635595,
        24.681666254723726,
        21.151581083066553,
    ],
    [
        32.851892878431265,
        36.440775336903165,
        35.75352506946652,
        37.30323390355503,
        34.53270452844875,
        31.563665533166038,
    ],
    [
        30.835038636245486,
        34.52475728758245,
        33.16493620256741,
        41.27828441760409,
        33.625782386520584,
        37.29511586329691,
    ],
    [
        33.83703170479853,
        40.43002001773247,
        36.184443329518004,
        37.12624122167327,
        35.08771179980609,
        35.139255640003434,
    ],
    [
        21.254977902045137,
        23.060454293882493,
        21.303095809926,
        23.591298842309634,
        20.611990194926875,
        24.660270597555062,
    ],
    [
        29.841626803079436,
        31.839284837338443,
        32.251032295284844,
        33.58647928947483,
        26.771111486581926,
        37.72191514982901,
    ],
    [
        44.55346815719802,
        45.05766658559051,
        52.471710160526506,
        49.79668060973384,
        39.22443377310308,
        47.134158002890395,
    ],
    [
        39.17357712318688,
        41.7702595729068,
        45.89022396002535,
        45.0492861943433,
        37.91715762872248,
        46.34889394407784,
    ],
    [
        27.66000266779389,
        21.58602345742282,
        26.20679184375472,
        22.50135429223267,
        24.80834623255724,
        21.774444708488712,
    ],
    [
        24.235631356398315,
        25.330637975611467,
        24.35882534462737,
        19.614644539141796,
        16.096340205031563,
        23.969425116432753,
    ],
    [
        32.714784659285534,
        37.00623479448081,
        32.80393996121548,
        35.850620589216966,
        29.775563785011812,
        40.1682995737016,
    ],
    [
        22.473439079254767,
        21.63905155108375,
        21.309990260094732,
        21.384623495336577,
        20.03459479396905,
        23.052219798176157,
    ],
    [
        25.833857096031426,
        28.006935843177743,
        27.529323359718862,
        31.562166519759444,
        27.60438145889944,
        28.783954098290046,
    ],
    [
        21.724755945837586,
        25.46457829943065,
        25.453384408355255,
        30.825724432379566,
        26.016894561298137,
        26.62071479383954,
    ],
    [
        21.81725003511163,
        24.05930480422645,
        25.802129103735567,
        28.29934558071534,
        24.882668372989844,
        27.55650443868176,
    ],
    [
        29.608735571248964,
        29.25708252569204,
        28.139612290451367,
        30.385015579072235,
        27.487966113975027,
        24.962848111956525,
    ],
    [
        39.403771170848984,
        41.917746619056196,
        37.84557861173111,
        40.69821075997771,
        35.62873141884051,
        45.079155106826896,
    ],
    [
        23.04120185595438,
        26.786780888021532,
        31.2639541627983,
        29.00755536709279,
        22.203982913169703,
        24.495657931209138,
    ],
    [
        32.403129266441965,
        35.840017821814904,
        37.17203998680061,
        38.03281061366912,
        32.78714804739385,
        39.28183659620253,
    ],
]
d_a = [
    [
        25.737584343831532,
        31.976583863756073,
        31.991796072794237,
        30.397770296625676,
        29.155071762578395,
        28.039906267294363,
    ],
    [
        39.18090234664077,
        43.45477045037197,
        51.786860206912635,
        43.09517606909757,
        36.467098828996626,
        45.87756004198227,
    ],
    [
        25.861638022497715,
        30.92652100750488,
        33.60725677532385,
        25.184249360983333,
        25.33461010279592,
        26.93195963760828,
    ],
    [
        40.44255724281372,
        42.96387247113741,
        54.334811744166686,
        43.93191831350745,
        40.27560778331666,
        40.23238565434164,
    ],
    [
        12.60992409857056,
        19.726007962924154,
        16.071644794959607,
        14.161701223195637,
        15.198176415121148,
        15.276991028259971,
    ],
    [
        29.04588467665483,
        33.80505200039617,
        38.04855856282448,
        24.880963308310946,
        27.92967131290679,
        28.76950184954046,
    ],
    [
        30.237437485971547,
        38.66537897520929,
        42.70573743205196,
        29.776069477137508,
        33.905932759342974,
        31.963285050639865,
    ],
    [
        28.028933057301142,
        38.08684708617139,
        37.544542228274906,
        25.321920916303146,
        31.533875714831566,
        29.298469894937988,
    ],
    [
        30.71424884632851,
        30.95498799285244,
        34.15309776359906,
        31.217888494294368,
        24.993625601516545,
        31.928661444788062,
    ],
    [
        26.146184710215547,
        32.46478475669627,
        31.62669482728407,
        25.64499465851579,
        26.152486540238574,
        24.145438152077382,
    ],
    [
        18.513444310864926,
        28.714219341908812,
        22.747896951556424,
        25.100699161335342,
        23.409036420228585,
        24.514405065630598,
    ],
    [
        34.97755034206531,
        41.735458041689185,
        42.23493985650944,
        36.009265517304414,
        34.352119472664846,
        37.644071472340144,
    ],
    [
        40.092702168388236,
        53.62113178600554,
        48.84356458863764,
        36.39043705817393,
        42.28637968655391,
        40.280349955856444,
    ],
    [
        27.09610113742262,
        41.26880485353593,
        39.346977655994735,
        27.591514389672355,
        31.995616644580924,
        28.87474376777405,
    ],
    [
        32.46320730114356,
        36.48181104762826,
        38.714631286168924,
        32.33019087297395,
        30.27052570417721,
        32.240695698900474,
    ],
    [
        24.62678794161355,
        36.21315596132247,
        34.91748692038032,
        27.31267275243153,
        29.942035228767967,
        28.65251738118413,
    ],
    [
        31.908551549750985,
        35.150861854275824,
        38.81804750019257,
        33.07677850315476,
        30.681610274970396,
        34.06666656257596,
    ],
    [
        30.48008982537277,
        35.58865452115646,
        40.57350768906002,
        33.041098811455335,
        29.65449752756178,
        32.70055131481196,
    ],
    [
        30.57739739495991,
        34.19410957043535,
        44.6772377886788,
        31.000103986208202,
        31.17905081422838,
        34.40131935305819,
    ],
    [
        32.07230397160679,
        40.95180323381044,
        41.63975849206306,
        34.856807118616764,
        34.237546314068695,
        36.52332793161486,
    ],
    [
        22.165179600850514,
        32.94179603315074,
        32.363977866166756,
        25.282871906957595,
        28.351785486262752,
        24.632822374929177,
    ],
    [
        28.16174839655334,
        35.544107679226634,
        28.118950582145626,
        28.3788129267007,
        28.77498009421306,
        28.156490872467124,
    ],
    [
        22.230950864818148,
        30.475956010093874,
        28.654962541126377,
        28.12410255465114,
        28.61378026798125,
        25.497683394160806,
    ],
    [
        22.53316383674741,
        26.507404577732675,
        32.26317495605627,
        20.718985703375136,
        26.011612018642545,
        25.31060496066265,
    ],
    [
        24.59110367631163,
        27.543472573592013,
        30.621636332862288,
        28.235921081852787,
        24.45902884286537,
        29.64689202372981,
    ],
    [
        21.024604742056106,
        30.902045272493908,
        29.293039097395038,
        24.61451220740829,
        24.812533834540638,
        23.893380083955623,
    ],
]

In [627]:
# Element-wise exact (==) comparison of the reference values `d_a` with the updated `d`.
# NOTE(review): the recorded output below contains a False entry. Exact float equality
# is sensitive to rounding; np.isclose(d_a, d) would tolerate it — TODO confirm the
# mismatch is only a rounding difference and not a real discrepancy.
d_a == d

array([[ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True, False,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  T

In [628]:
# Element-wise exact (==) comparison of the reference values `w_a` with the updated `w`.
# NOTE(review): exact float equality is sensitive to rounding; np.isclose(w_a, w)
# would be a more tolerant check.
w_a == w

array([[ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  T

### [Напишите функцию, которая для заданного слова получит список наиболее похожих на него слов вместе с оценкой близости.](https://stepik.org/lesson/261476/step/8?unit=242225)

In [833]:
def get_nearest(embeddings, query_word_id, get_n):
    """
    Rank all words by how close their embedding direction is to the query word's.

    embeddings - VocabSize x EmbSize - word embeddings
    query_word_id - integer - id of query word to find most similar to
    get_n - integer - number of most similar words to retrieve

    returns list of at most `get_n` tuples (word_id, similarity) sorted by descending order of similarity value

    similarity is the negated Euclidean distance between the two L2-normalised
    vectors: 0 for vectors pointing the same way, -2 for opposite vectors.
    The query word itself is part of the ranking. Vectors whose norm is not
    positive (e.g. all-zero rows) have no direction; they get similarity -inf
    so they sort last instead of producing NaN scores that break the ordering.
    """

    def to_unit(vec):
        # Returns None for vectors that cannot be normalised (zero norm).
        norm = np.linalg.norm(vec)
        if norm > 0:
            return vec / norm
        return None

    # The query is the same for every row, so normalise it only once.
    query_unit = to_unit(embeddings[query_word_id])

    def get_proximity_score(emb):
        emb_unit = to_unit(emb)
        if query_unit is None or emb_unit is None:
            return -np.inf
        return -np.linalg.norm(query_unit - emb_unit)

    result = [(i, get_proximity_score(emb)) for i, emb in enumerate(embeddings)]

    # list.sort is stable, so equally similar words keep their vocabulary order.
    result.sort(key=lambda item: item[1], reverse=True)

    return result[:get_n]

In [834]:
# Smoke run: 3 embeddings of size 9, query word 0, and a `get_n` (8) larger than
# the number of rows, so every row is expected back in the ranking.
r = get_nearest(
    embeddings=np.array(
        [
            [
                0.7299015792584768,
                0.2915364327741303,
                0.5307571134639943,
                0.3101345732086396,
                0.8327085262119636,
                0.39018382511314353,
                0.678094726221033,
                0.12372148102696612,
                0.5966533433209616,
            ],
            [
                0.5411155947267721,
                0.046791742239819856,
                0.5358832195593092,
                0.09894162419462038,
                0.6350557173679914,
                0.15126161842015717,
                0.11375720216711405,
                0.46954553941325416,
                0.8281402097264261,
            ],
            [
                0.5323869209381028,
                0.2005012376766715,
                0.5925043884236925,
                0.4621530177251649,
                0.3886830034303448,
                0.6403738184472031,
                0.23320289120963578,
                0.43574647265888766,
                0.5305633832484254,
            ],
        ]
    ),
    query_word_id=0,
    get_n=8,
)

In [835]:
r

[(0, -0.0), (2, -0.5028921892757982), (1, -0.5422310219843909)]