This notebook was made for learning purposes and is inspired by the work at https://www.kaggle.com/code/priyankdl/word2vec-skipgraom-cbow

In [None]:
!pip install torch==2.0.1 torchtext==0.15.2
!pip install portalocker>=2.0.0



Importing the dependencies

In [None]:
#for the model
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch import optim

#For the dataset part
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext import datasets
from functools import partial

import numpy as np

The data part

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device=torch.device(
    type='cuda' if torch.cuda.is_available() else 'cpu',
    index=0
)

In [None]:
# Both IMDB splits; iterating over a split yields (label, review) pairs.
train_data=datasets.IMDB(split="train")
test_data=datasets.IMDB(split="test")

# Only the review text is needed, and only the leading part of each split:
# collection stops as soon as a list holds more than 1000 reviews.
train_reviews=[]
for seen,(_,text) in enumerate(train_data,start=1):
  train_reviews.append(text)
  if seen>1000:
    break

test_reviews=[]
for seen,(_,text) in enumerate(test_data,start=1):
  test_reviews.append(text)
  if seen>1000:
    break

# train_reviews / test_reviews now hold the raw review strings

tokenizer=get_tokenizer("basic_english","en")

# Build the vocabulary from the tokenised training reviews only. min_freq=20
# sets the frequency threshold for keeping a token, and "<unk>" is added as a
# special token placed after the regular ones (special_first=False).
vocab=build_vocab_from_iterator(
    (tokenizer(text) for text in train_reviews),
    specials=["<unk>"],
    special_first=False,
    min_freq=20
)
# Tokens missing from the vocabulary resolve to the "<unk>" index.
vocab.set_default_index(vocab["<unk>"])

def colate_function(batch,text_pipeline,window_size=4):
  """Turn a batch of raw reviews into skip-gram (centre, context) training pairs.

  A window of ``2*window_size+1`` token indices slides over every review. For
  each window position the centre index is emitted once per context index, so
  the two returned tensors line up element by element.

  Args:
    batch: iterable of raw review strings.
    text_pipeline: callable mapping one review to a list of integer token
      indices. It is used for the conversion, so the function no longer
      depends on module-level tokenizer/vocabulary objects.
    window_size: number of context tokens taken on each side of the centre
      token. The default of 4 gives the 9-token window used previously.

  Returns:
    A tuple ``(centres, contexts)`` of 1-D ``torch.long`` tensors of equal
    length. Both are empty when no review is long enough for a full window.
  """
  span=2*window_size+1
  centres=[]
  contexts=[]
  for review in batch:
    indices=list(text_pipeline(review))
    # Reviews shorter than one full window give an empty range and are skipped.
    for start in range(len(indices)-span+1):
      window=indices[start:start+span]
      # Repeat the centre token once for every surrounding context token.
      centres.extend([window[window_size]]*(span-1))
      contexts.extend(window[:window_size]+window[window_size+1:])

  centres=torch.tensor(centres,dtype=torch.long)
  contexts=torch.tensor(contexts,dtype=torch.long)

  return centres,contexts

def text_pipeline(review):
  """Tokenise one review and return its tokens as a list of vocabulary indices."""
  tokens=tokenizer(review)
  return vocab.lookup_indices(tokens)


# Both loaders share the same settings: 64 reviews per batch, shuffled, and
# collated into (centre, context) index tensors by colate_function.
loader_options=dict(
    batch_size=64,
    shuffle=True,
    collate_fn=partial(colate_function,text_pipeline=text_pipeline)
)

train_dataloader=DataLoader(train_reviews,**loader_options)
test_dataloader=DataLoader(test_reviews,**loader_options)

In [None]:
class SkipGram(nn.Module):
  """Skip-gram model: an embedding lookup followed by a linear layer that
  scores every vocabulary entry as a possible context word.

  Args:
    vocab_size: number of rows in the embedding table and number of output
      logits.
    embedding_dim: size of each embedding vector (default 300, as before).
    max_norm: passed to ``nn.Embedding``; embedding rows whose norm exceeds
      this value are renormalised on lookup. ``None`` disables the constraint
      (default 1, as before).
  """
  def __init__(self,vocab_size,embedding_dim=300,max_norm=1):
    super().__init__()
    self.ebd1=nn.Embedding(
        num_embeddings=vocab_size,
        embedding_dim=embedding_dim,
        max_norm=max_norm
    )

    # The projection width follows embedding_dim so the two layers always match.
    self.linear1=nn.Linear(
        in_features=embedding_dim,
        out_features=vocab_size
    )

  def forward(self,x):
    """Map a tensor of token indices to logits over the vocabulary."""
    embedded=self.ebd1(x)
    return self.linear1(embedded)

# One output logit per vocabulary entry; the model lives on the selected device.
model=SkipGram(len(vocab)).to(device)
optimiser=optim.Adam(model.parameters(),lr=1e-3)
loss_function=nn.CrossEntropyLoss()


def train_one_epoch(model,dataloader,optimiser,loss_function):
  """Run one optimisation pass over ``dataloader`` and print the mean batch loss.

  Args:
    model: network to train. Batches are moved to the device of its
      parameters, so no module-level ``device`` variable is required.
    dataloader: iterable of batches where element 0 is the input tensor and
      element 1 is the target tensor.
    optimiser: optimiser that updates the parameters of ``model``.
    loss_function: callable mapping ``(output, target)`` to a scalar loss.

  Returns:
    The mean of the per-batch losses as a float, or ``nan`` when the loader
    produced no non-empty batch.
  """
  model.train()
  model_device=next(model.parameters()).device
  running_loss=[]

  for batch_data in dataloader:
    # A batch without any training pair would give a NaN mean loss; skip it
    # instead of stepping the optimiser on it.
    if batch_data[1].numel()==0:
      continue
    inputs=batch_data[0].to(model_device)
    target=batch_data[1].to(model_device)
    output=model(inputs)
    loss=loss_function(output,target)
    running_loss.append(loss.item())
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

  mean_loss=np.mean(running_loss) if running_loss else float("nan")
  print("The loss at the end of this epoch is ",mean_loss)
  return float(mean_loss)

def eval_one_epoch(model,dataloader,loss_function):
  """Evaluate ``model`` on ``dataloader`` and print the mean batch loss.

  No parameters are updated. The loop runs under ``torch.no_grad()`` so no
  autograd graph is built for the forward passes.

  Args:
    model: network to evaluate. Batches are moved to the device of its
      parameters, so no module-level ``device`` variable is required.
    dataloader: iterable of batches where element 0 is the input tensor and
      element 1 is the target tensor.
    loss_function: callable mapping ``(output, target)`` to a scalar loss.

  Returns:
    The mean of the per-batch losses as a float, or ``nan`` when the loader
    produced no non-empty batch.
  """
  model.eval()
  model_device=next(model.parameters()).device
  running_loss=[]

  with torch.no_grad():
    for batch_data in dataloader:
      # A batch without any pair would give a NaN mean loss; leave it out.
      if batch_data[1].numel()==0:
        continue
      inputs=batch_data[0].to(model_device)
      target=batch_data[1].to(model_device)
      output=model(inputs)
      loss=loss_function(output,target)
      running_loss.append(loss.item())

  mean_loss=np.mean(running_loss) if running_loss else float("nan")
  print("The loss at the end of this epoch is ",mean_loss)
  return float(mean_loss)


# Five passes over the data: train first, then report the held-out loss.
for epoch in range(1,6):
  print("EPOCH number:",epoch)
  train_one_epoch(model,train_dataloader,optimiser,loss_function)
  eval_one_epoch(model,test_dataloader,loss_function)




EPOCH number: 1
The loss at the end of this epoch is  7.040234625339508
The loss at the end of this epoch is  6.954651653766632
EPOCH number: 2
The loss at the end of this epoch is  6.8429993987083435
The loss at the end of this epoch is  6.663528472185135
EPOCH number: 3
The loss at the end of this epoch is  6.490560740232468
The loss at the end of this epoch is  6.2429671585559845
EPOCH number: 4
The loss at the end of this epoch is  6.070838272571564
The loss at the end of this epoch is  5.825994580984116
EPOCH number: 5
The loss at the end of this epoch is  5.7057976722717285
The loss at the end of this epoch is  5.502445876598358
