In [None]:
# !pip install datasets
import torch    
from tqdm.auto import tqdm
from word2vec.data_setup import load_skipgram_data, create_data_loaders, generate_negative_samples
from word2vec.model import SkipGramNegativeSampling, NegativeSampleLoss
from word2vec.utils import save_embeddings, save_checkpoint

In [None]:
# hyperparameters
vocab_size = 20000
embedding_dim = 300
context_size = 5

epochs = 3
batch_size = 128

# Reproducibility: torch.manual_seed seeds torch's RNG on all devices, so weight
# initialisation and any torch-based sampling repeat across runs.
# NOTE(review): if word2vec.data_setup draws from `random` or `numpy` RNGs, seed those too.
seed = 42
torch.manual_seed(seed)

# device agnostic code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Build the skip-gram dataset (plus the noise distribution used later for
# negative sampling) from 200 articles, then wrap it in a batched dataloader.
dataset, noise_dist = load_skipgram_data(
    vocab_size,
    context_size,
    amount_of_articles=200,
)
train_dataloader = create_data_loaders(
    dataset,
    batch_size=batch_size,
)

In [None]:
learning_rate = 1e-3

# Instantiate the model and move it to the target device *before* building the
# optimizer, so the optimizer holds references to the on-device parameters.
model = SkipGramNegativeSampling(vocab_size, embedding_dim)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
loss_function = NegativeSampleLoss()
model

In [None]:
# Total number of batches per epoch
total_batches = len(train_dataloader)
# Report the running loss roughly five times per epoch. Clamp to at least 1:
# with fewer than 5 batches, `total_batches // 5` is 0 and the modulo check
# below would raise ZeroDivisionError.
report_interval = max(1, total_batches // 5)
# Passed as `n_samples` to generate_negative_samples for every batch
# (presumably the number of noise words per target — confirm in data_setup).
n_negative_samples = 5

for epoch in tqdm(range(epochs)):
    model.train()
    train_loss = 0.0
    for batch_idx, (target, context) in enumerate(train_dataloader):
        # Negatives are sized from this batch's actual length, so a smaller
        # final batch is handled.
        negative_samples = generate_negative_samples(
            n_samples=n_negative_samples,
            noise_dist=noise_dist,
            batch_size=target.shape[0],
        )
        target, context, negative_samples = target.to(device), context.to(device), negative_samples.to(device)

        embedded_center = model.forward_input(target)
        embedded_context = model.forward_output(context)
        embedded_noise = model.forward_noise(negative_samples)

        loss = loss_function(embedded_center, embedded_context, embedded_noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python float, so the accumulator does not keep the graph alive
        train_loss += loss.item()

        # Report the running average loss and checkpoint at regular intervals
        if (batch_idx + 1) % report_interval == 0:
            avg_loss = train_loss / (batch_idx + 1)
            print(f"\tEpoch {epoch + 1}, batches processed: {(batch_idx + 1) / total_batches * 100:.1f}%, Loss: {avg_loss:.2f}")

            # Save model and embeddings (same paths every time — presumably overwritten; confirm in word2vec.utils)
            save_checkpoint(model, 'skipgram.pt')
            save_embeddings(model, dataset.word_to_idx, 'word_embeddings.txt')

    # Final loss report for the epoch; max(..., 1) guards an empty dataloader
    avg_loss = train_loss / max(total_batches, 1)
    print(f"Epoch {epoch + 1} Average Loss: {avg_loss:.2f}")
