<a href="https://colab.research.google.com/github/EtienneFerrandi/Word-embedding/blob/main/Word_embedding_CBW.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import string

In [None]:
# Load the corpus and split it into whitespace-separated tokens.
# The encoding is passed explicitly so the read does not depend on the platform default.
# NOTE(review): assumes AU_s_9.txt is UTF-8 encoded -- confirm against the source file.
with open("AU_s_9.txt", "r", encoding="utf-8") as corpus_file:
    corpus_tokens = corpus_file.read().split()

# Remove ASCII punctuation from every token. Tokens made only of punctuation become
# empty strings; they are dropped so that '' never enters the vocabulary or a context.
_punctuation_table = str.maketrans("", "", string.punctuation)
raw_text = [
    cleaned
    for cleaned in (token.translate(_punctuation_table) for token in corpus_tokens)
    if cleaned
]

# Show only a short preview instead of dumping the whole corpus into the cell output.
raw_text[:20]

In [66]:
# Vocabulary: the distinct tokens that occur in the corpus, and how many there are.
vocab = {token for token in raw_text}
vocab_size = len(vocab)

In [67]:
# Assign ids in sorted word order. Iterating a set of strings directly yields a
# different order in each Python process (string hashing is randomized), which would
# change every word id -- and therefore the meaning of the trained weights -- on each
# kernel restart.
word_to_ix = {word: i for i, word in enumerate(sorted(vocab))}
# Build the inverse mapping from the forward one so the two can never disagree.
ix_to_word = {i: word for word, i in word_to_ix.items()}

In [68]:
CONTEXT_SIZE = 2  # number of context words taken on EACH side of the target word (CBOW window)

# Build the (context, target) training pairs. Positions closer than CONTEXT_SIZE to
# either end of the corpus are skipped because they lack a full window.
data = []
for i in range(CONTEXT_SIZE, len(raw_text) - CONTEXT_SIZE):
    # Left neighbours followed by right neighbours, in corpus order; the target
    # word itself (position i) is left out.
    context = raw_text[i - CONTEXT_SIZE:i] + raw_text[i + 1:i + CONTEXT_SIZE + 1]
    target = raw_text[i]
    data.append((context, target))
print(data[:5])

[(['Tractatus', 'sancti', 'de', 'decem'], 'Augustini'), (['sancti', 'Augustini', 'decem', 'chordis'], 'de'), (['Augustini', 'de', 'chordis', 'sermo'], 'decem'), (['de', 'decem', 'sermo', 'habitus'], 'chordis'), (['decem', 'chordis', 'habitus', 'Chusa'], 'sermo')]


In [69]:
class CBOW(nn.Module):
    """Continuous bag-of-words model: predicts a target word from its context words.

    The context word embeddings are summed, passed through one ReLU hidden layer and
    projected to one score per vocabulary word; the output is log-probabilities.

    Args:
        vocab_size: number of distinct words (rows of the embedding table and size
            of the output layer).
        embedding_dim: dimensionality of each word embedding.
        hidden_dim: width of the hidden layer (defaults to 128).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim=128):
        super(CBOW, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(embedding_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, inputs):
        """Return log-probabilities over the vocabulary.

        Args:
            inputs: integer tensor of word indices, either a single context of shape
                (context_len,) or a batch of contexts of shape (batch, context_len).

        Returns:
            Tensor of shape (1, vocab_size) for a single context, or
            (batch, vocab_size) for a batch.
        """
        embeds = self.embeddings(inputs)
        # Sum over the context-word axis (second to last), which keeps any leading
        # batch axis intact instead of collapsing it.
        summed = embeds.sum(dim=-2)
        if summed.dim() == 1:
            # Single un-batched context: add the batch axis of size 1.
            summed = summed.unsqueeze(0)
        hidden = F.relu(self.linear1(summed))
        out = self.linear2(hidden)
        log_probs = F.log_softmax(out, dim=-1)
        return log_probs

In [70]:
def make_context_vector(context, word_to_ix):
    """Convert a sequence of context words into a 1-D tensor of their indices.

    Args:
        context: iterable of words.
        word_to_ix: mapping from word to integer index.

    Returns:
        A ``torch.long`` tensor holding one index per word, in the given order.

    Raises:
        KeyError: if a word is missing from ``word_to_ix``.
    """
    indices = []
    for word in context:
        indices.append(word_to_ix[word])
    return torch.tensor(indices, dtype=torch.long)

In [71]:
# Training configuration.
SEED = 0              # fixes the random weight initialisation so the loss curve is reproducible
EMBEDDING_DIM = 20    # size of each word vector
LEARNING_RATE = 0.001
N_EPOCHS = 100

torch.manual_seed(SEED)

losses = []
# NLLLoss takes log-probabilities, which is what CBOW.forward returns (log_softmax).
loss_function = nn.NLLLoss()
model = CBOW(vocab_size, embedding_dim=EMBEDDING_DIM)
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)  # plain stochastic gradient descent

# The index tensors do not change between epochs, so build them once up front
# instead of re-creating them for every example in every epoch.
training_pairs = [
    (make_context_vector(context, word_to_ix),
     torch.tensor([word_to_ix[target]], dtype=torch.long))
    for context, target in data
]

for epoch in range(N_EPOCHS):
    total_loss = 0
    # One parameter update per (context, target) example.
    for context_idxs, target_idx in training_pairs:
        optimizer.zero_grad()  # clear gradients left over from the previous step
        log_probs = model(context_idxs)
        loss = loss_function(log_probs, target_idx)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    losses.append(total_loss)  # summed (not averaged) loss over the epoch

print(losses)

[53293.108147382736, 50253.469131708145, 48751.78603184223, 47672.16461801529, 46768.78481078148, 45960.62044394016, 45203.96868211031, 44470.8713760376, 43743.13012647629, 43007.847942113876, 42257.127690792084, 41483.61944037676, 40683.2115137279, 39853.445836126804, 38991.53594198823, 38096.115712314844, 37167.10522556305, 36204.85303398967, 35211.43820346892, 34189.33372756094, 33141.734033979475, 32072.105412974954, 30986.086135480553, 29888.992124021053, 28787.422549234703, 27688.314651513472, 26598.935599895194, 25526.199487818405, 24477.362013080157, 23459.277985659428, 22477.484794866294, 21536.75844200235, 20640.294699640945, 19790.20813646447, 18987.552151189186, 18231.745167727582, 17521.761984748766, 16855.904907517135, 16231.31908501871, 15645.655530881137, 15095.758815924637, 14579.044319703244, 14093.003472141456, 13635.125820330344, 13202.936497669434, 12794.77383004874, 12408.665450908476, 12042.804421952693, 11695.461307525402, 11366.054533780552, 11053.057803976117,

In [72]:
# Sanity check: predict the centre word for one hand-picked context.
# The expected target word for this context is 'qui'.
context = ['tunc', 'dies', 'diceretur', 'hodiernus']
with torch.no_grad():  # inference only, no gradients are needed
    context_vector = make_context_vector(context, word_to_ix)
    predict = model(context_vector)
# Pick the vocabulary index with the highest log-probability and map it back to its word.
predicted_index = predict.argmax(dim=1).item()
print(ix_to_word[predicted_index])

qui
