In [25]:
# !pip install datasets
import torch    
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from word2vec.data_setup import load_skipgram_data, create_data_loaders
from word2vec.model import SkipGramNegativeSampling, NegativeSampleLoss
from word2vec.utils import save_embeddings, load_embeddings, visualize_embeddings

In [26]:
# hyperparameters
vocab_size = 1000
embedding_dim = 100
context_size = 3

epochs = 3
batch_size = 32

# Seed torch's global RNG so that model weight initialisation and the
# torch.multinomial negative sampling below are reproducible across runs.
seed = 42
torch.manual_seed(seed)

# device agnostic code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [27]:
# Build the skip-gram dataset (yielding (target, context) pairs for training)
# from 200 articles, along with the noise distribution that negative samples
# are drawn from; then batch the dataset for the training loop.
dataset, noise_dist = load_skipgram_data(
    vocab_size,
    context_size,
    amount_of_articles=200,
)
train_dataloader = create_data_loaders(
    dataset,
    batch_size=batch_size,
)

[nltk_data] Downloading package punkt to /Users/aspisov/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Found cached dataset wikipedia (/Users/aspisov/.cache/huggingface/datasets/wikipedia/20220301.simple/2.0.0/d41137e149b2ea90eead07e7e3f805119a8c22dd1d5b61651af8e3e3ee736001)


  0%|          | 0/1 [00:00<?, ?it/s]

Total tokens: 0.15M


In [28]:
# Skip-gram model with negative sampling, placed on the selected device.
model = SkipGramNegativeSampling(vocab_size, embedding_dim)
model = model.to(device)

# Adam optimiser (learning rate 1e-3) and the negative-sampling objective.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_function = NegativeSampleLoss()

model

number of parameters: 0.20M


SkipGramNegativeSampling(
  (input_embeddings): Embedding(1000, 100)
  (output_embeddings): Embedding(1000, 100)
)

In [29]:

def generate_negative_samples(n_samples, noise_dist, batch_size=batch_size):
    """Draw noise (negative) word indices for one batch of skip-gram examples.

    Args:
        n_samples: number of negative samples to draw per example.
        noise_dist: sampling weights over the vocabulary, forwarded to
            ``torch.multinomial`` (assumed to be a 1-D tensor — confirm
            against ``load_skipgram_data``).
        batch_size: number of examples in the batch. The default is the
            global ``batch_size`` captured when this cell was executed.

    Returns:
        Tensor of sampled indices reshaped to (batch_size, n_samples).
    """
    total_draws = n_samples * batch_size
    # Sampling with replacement: the same noise word may appear multiple times.
    flat_indices = torch.multinomial(noise_dist, total_draws, replacement=True)
    return flat_indices.view(batch_size, n_samples)

In [30]:
# Total number of batches per epoch
total_batches = len(train_dataloader)
# Report progress roughly 5 times per epoch. Clamp to 1 so the modulo below
# never divides by zero when the loader yields fewer than 5 batches.
report_interval = max(1, total_batches // 5)
# Number of negative (noise) samples drawn per (center, context) example.
n_negative_samples = 5

for epoch in tqdm(range(epochs)):
    model.train()
    train_loss = 0
    for batch_idx, (target, context) in enumerate(train_dataloader):
        # Size the noise batch from the actual batch, since the last batch
        # of an epoch may be smaller than `batch_size`.
        negative_samples = generate_negative_samples(
            n_samples=n_negative_samples,
            noise_dist=noise_dist,
            batch_size=target.shape[0],
        )
        target, context, negative_samples = target.to(device), context.to(device), negative_samples.to(device)

        embedded_center = model.forward_input(target)
        embedded_context = model.forward_output(context)
        embedded_noise = model.forward_noise(negative_samples)

        loss = loss_function(embedded_center, embedded_context, embedded_noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Report the running average loss at regular intervals. tqdm.write
        # prints without corrupting the active progress bar.
        if (batch_idx + 1) % report_interval == 0:
            avg_loss = train_loss / (batch_idx + 1)
            tqdm.write(f"\tEpoch {epoch + 1}, batches processed: {(batch_idx + 1) / total_batches * 100:.1f}%, Loss: {avg_loss:.2f}")

    # Final loss report for the epoch (guarded against an empty loader)
    avg_loss = train_loss / max(total_batches, 1)
    tqdm.write(f"Epoch {epoch + 1} Average Loss: {avg_loss:.2f}")

    # Save embeddings every epoch
    save_embeddings(model, dataset.word_to_idx, 'data/word_embeddings.txt')


  0%|          | 0/3 [00:00<?, ?it/s]

	Epoch 1, batches processed: 20.0%, Loss: 4.47
	Epoch 1, batches processed: 40.0%, Loss: 3.55
	Epoch 1, batches processed: 60.0%, Loss: 3.14
	Epoch 1, batches processed: 80.0%, Loss: 2.92
	Epoch 1, batches processed: 100.0%, Loss: 2.77
Epoch 1 Average Loss: 2.77
	Epoch 2, batches processed: 20.0%, Loss: 2.12
	Epoch 2, batches processed: 40.0%, Loss: 2.12
	Epoch 2, batches processed: 60.0%, Loss: 2.12
	Epoch 2, batches processed: 80.0%, Loss: 2.11
	Epoch 2, batches processed: 100.0%, Loss: 2.11
Epoch 2 Average Loss: 2.11
	Epoch 3, batches processed: 20.0%, Loss: 2.04
	Epoch 3, batches processed: 40.0%, Loss: 2.05
	Epoch 3, batches processed: 60.0%, Loss: 2.05
	Epoch 3, batches processed: 80.0%, Loss: 2.05
	Epoch 3, batches processed: 100.0%, Loss: 2.05
Epoch 3 Average Loss: 2.05
