!pip install torch==2.0.1 torchtext==0.15.2
!pip install 'portalocker>=2.0.0'

In [None]:
import torch
import torch.nn as nn
from torchtext import datasets
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from functools import partial
import random


In [20]:
# Device configuration: use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

In [21]:
# Load dataset.
# Each AG_NEWS sample unpacks as a 2-tuple; the first element is discarded and only
# the text (second element) is kept, because the skip-gram objective trains on raw text.
train_data = datasets.AG_NEWS(split='train')
test_data = datasets.AG_NEWS(split='test')
req_train_data = list(text for _label, text in train_data)
req_test_data = list(text for _label, text in test_data)

In [22]:
# Tokenization and vocabulary building.
# Tokens seen fewer than `min_word_freq` times in the training text are left out of the
# vocabulary (lookups for them fall back to "<unk>").
min_word_freq = 15
tokenizer = get_tokenizer("basic_english", language="en")

def build_vocab(req_train_data, tokenizer, min_freq=None):
    """Build a torchtext vocabulary from an iterable of raw training strings.

    Args:
        req_train_data: iterable of raw text strings.
        tokenizer: callable mapping one string to a list of tokens.
        min_freq: minimum number of occurrences a token needs to be included.
            ``None`` (the default) falls back to the notebook-level
            ``min_word_freq`` setting.

    Returns:
        A vocabulary with ``"<unk>"`` as a special token and as the default
        index, so out-of-vocabulary lookups resolve to the ``<unk>`` index.
    """
    if min_freq is None:
        min_freq = min_word_freq
    vocab = build_vocab_from_iterator(
        map(tokenizer, req_train_data),
        specials=["<unk>"],
        min_freq=min_freq,
    )
    vocab.set_default_index(vocab["<unk>"])
    return vocab

vocab = build_vocab(req_train_data, tokenizer)
vocab_size = len(vocab)

# Skip-gram / training hyperparameters
window_size = 4        # context tokens taken on EACH side of a centre word
max_norm = 1           # passed to nn.Embedding(max_norm=...) in the model
embed_dim = 300        # embedding dimensionality
batch_size = 16        # raw texts per DataLoader batch (each text yields many pairs)
num_neg_samples = 3    # negative targets drawn per centre position


def text_pipeline(x):
    """Tokenize ``x`` and map each token to its vocabulary index.

    Unknown tokens map to the ``<unk>`` index (the vocabulary's default index).
    A named function, unlike a lambda, can be pickled, which DataLoader worker
    processes require when ``num_workers > 0``.
    """
    return [vocab[token] for token in tokenizer(x)]

In [23]:
# Make skipgram input function
def make_skipgram_input(batch, text_pipeline, num_neg_samples, window=None, n_vocab=None):
    """Collate raw texts into (centre word, target word, label) training triples.

    For every position that has ``window`` tokens on both sides, emit:
      * ``2 * window`` positive pairs (label 1): the left context, then the right context;
      * ``num_neg_samples`` negative pairs (label 0). Each negative target is drawn
        uniformly from the vocabulary indices that do not occur anywhere in the same text.
    Texts shorter than ``2 * window + 1`` tokens contribute nothing.

    Args:
        batch: sequence of raw text strings (one DataLoader batch).
        text_pipeline: callable mapping a string to a list of vocabulary indices.
        num_neg_samples: number of negatives drawn per centre position.
        window: context size on each side; ``None`` uses the global ``window_size``.
        n_vocab: negatives are drawn from ``range(n_vocab)``; ``None`` uses the
            global ``vocab_size``.

    Returns:
        Three 1-D ``torch.long`` tensors of equal length: centre indices, target
        indices, and 0/1 labels. They are empty (but still ``long``) when no text
        in the batch is long enough.

    Raises:
        ValueError: if negatives are requested but a text contains every
            vocabulary index, leaving nothing to sample from.
    """
    if window is None:
        window = window_size
    if n_vocab is None:
        n_vocab = vocab_size

    batch_input_word, batch_target_words, batch_labels = [], [], []

    for text in batch:
        text_tokens = text_pipeline(text)
        if len(text_tokens) < 2 * window + 1:
            continue

        current_words = set(text_tokens)
        # Rejection sampling (below) gives the same uniform distribution over
        # indices absent from this text. It avoids materialising an
        # O(n_vocab) candidate list for every text in every batch.
        n_candidates = n_vocab - sum(1 for idx in current_words if 0 <= idx < n_vocab)
        if num_neg_samples > 0 and n_candidates <= 0:
            raise ValueError("no vocabulary index left to draw negative samples from")

        for center in range(window, len(text_tokens) - window):
            input_word = text_tokens[center]

            # Positive pairs: left context, then right context, skipping the centre itself.
            for j in range(center - window, center + window + 1):
                if j == center:
                    continue
                batch_input_word.append(input_word)
                batch_target_words.append(text_tokens[j])
                batch_labels.append(1)

            for _ in range(num_neg_samples):
                neg = random.randrange(n_vocab)
                while neg in current_words:
                    neg = random.randrange(n_vocab)
                batch_input_word.append(input_word)
                batch_target_words.append(neg)
                batch_labels.append(0)

    # Explicit long dtype: torch.tensor([]) would default to float32, which
    # nn.Embedding rejects as an index tensor.
    return (
        torch.tensor(batch_input_word, dtype=torch.long),
        torch.tensor(batch_target_words, dtype=torch.long),
        torch.tensor(batch_labels, dtype=torch.long),
    )

In [24]:
# DataLoader setup
# A single collate function, shared by both loaders, turns each batch of raw texts
# into skip-gram (centre, target, label) tensors.
skipgram_collate = partial(
    make_skipgram_input,
    text_pipeline=text_pipeline,
    num_neg_samples=num_neg_samples,
)

train_skipgram = DataLoader(
    req_train_data,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=skipgram_collate,
)

# Evaluation data needs no shuffling; a fixed batch order keeps evaluation
# reproducible. Negative targets are still random unless `random` is seeded.
test_skipgram = DataLoader(
    req_test_data,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=skipgram_collate,
)

In [25]:
# Define model
class NegSkipGram(nn.Module):
    """Skip-gram with negative sampling (SGNS) pair scorer.

    As in word2vec, there are two embedding tables:
      * ``embeddings`` holds the centre (input) word vectors;
      * ``context_embeddings`` holds the target (output) word vectors.
    A (centre, target) pair is scored by the dot product of its two vectors. The
    training loss applies a sigmoid to that logit.

    Both tables use the notebook-level ``max_norm``. With it set, ``nn.Embedding``
    renormalises looked-up rows in place on every forward pass.

    NOTE(review): with max_norm=1 both vectors have L2 norm <= 1, so every logit lies
    in [-1, 1] and the sigmoid stays within roughly [0.27, 0.73]. If the loss plateaus,
    consider a larger max_norm (or None).
    """

    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embeddings = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_dim, max_norm=max_norm
        )
        # A separate table for target words (word2vec's "output" vectors). Sharing one
        # table makes each word maximally similar to itself and couples the two roles.
        self.context_embeddings = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_dim, max_norm=max_norm
        )

    def forward(self, input_words, target_words):
        """Return one logit per pair: the dot product <embeddings[input], context_embeddings[target]>.

        ``input_words`` and ``target_words`` are index tensors of the same shape.
        The result has that same shape; the reduction is over the embedding axis only.
        """
        input_embeds = self.embeddings(input_words)
        target_embeds = self.context_embeddings(target_words)
        return (input_embeds * target_embeds).sum(dim=-1)


In [26]:
# Training function
def train_one_epoch(model, dataloader, opt):
    """Run one optimisation pass of ``model`` over ``dataloader``.

    Each batch is an ``(inputs, target, labels)`` tuple of index and 0/1 label
    tensors, as produced by ``make_skipgram_input``. The loss is binary
    cross-entropy on the pair logits z:
        -mean(labels * logsigmoid(z) + (1 - labels) * logsigmoid(-z)).

    Batches that contain no pairs (every text too short for a full window) are
    skipped. Averaging over zero elements would give NaN and poison the epoch average.

    Args:
        model: module mapping ``(inputs, target)`` index tensors to logits.
        dataloader: iterable of ``(inputs, target, labels)`` batches.
        opt: optimizer over ``model``'s parameters.

    Returns:
        The mean loss over non-empty batches, or ``nan`` if there were none.
        The value is also printed.
    """
    model.train()
    # Send batches to wherever the model's parameters live instead of relying on a global.
    device = next(model.parameters()).device
    running_loss = 0.0
    n_batches = 0

    for inputs, target, labels in dataloader:
        if labels.numel() == 0:
            continue
        inputs = inputs.to(device)
        target = target.to(device)
        labels = labels.float().to(device)

        opt.zero_grad()
        logits = model(inputs, target)
        # Numerically stable equivalent of the logsigmoid formulation above.
        loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
        loss.backward()
        opt.step()

        running_loss += loss.item()
        n_batches += 1

    average_loss = running_loss / n_batches if n_batches else float("nan")
    print(f'Average Loss: {average_loss:.4f}')
    return average_loss

In [27]:
# Setup for training
# Seed every RNG the pipeline draws from so that Restart & Run All is reproducible:
#   * `random` supplies the negative samples in the collate function;
#   * torch drives weight initialisation and DataLoader shuffling.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

n_epochs = 10
model = NegSkipGram(vocab_size, embed_dim).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001)




In [28]:
# Training loop: one full pass over the training loader per epoch (epochs numbered from 1).
for epoch in range(1, n_epochs + 1):
    print(f"Epoch {epoch}/{n_epochs}")
    train_one_epoch(model, train_skipgram, opt)

Epoch 1/10


KeyboardInterrupt: 