# Maxence Lasbordes | MASH

In [114]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from tqdm import tqdm
from transformers import BertTokenizer

In [115]:
# Loading the dataset
dataset = load_dataset("scikit-learn/imdb", split="train")

print(dataset)

Dataset({
    features: ['review', 'sentiment'],
    num_rows: 50000
})


In [116]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

In [117]:
def preprocessing_fn(x, tokenizer):
    x["review_ids"] = tokenizer(
        x["review"],
        add_special_tokens=False,
        truncation=True,
        max_length=256,
        padding=False,
        return_attention_mask=False,
    )["input_ids"]
    x["label"] = 0 if x["sentiment"] == "negative" else 1
    return x

In [118]:
n_samples = 5000  # the number of training example

# We first shuffle the data !
dataset = dataset.shuffle(seed=42)

# Select 5000 samples
dataset = dataset.select(range(n_samples))

# Tokenize the dataset
dataset = dataset.map(lambda x: preprocessing_fn(x, tokenizer))

# Remove useless columns
dataset = dataset.remove_columns(["review", "sentiment"])

# Split the train and validation
split = dataset.train_test_split(test_size=0.1, seed=42)

document_train_set = split["train"]
document_valid_set = split["test"]

In [119]:
def extract_words_contexts(text_ids, R):
    words = []
    contexts = []
    for i in range(len(text_ids)):
        words.append(text_ids[i])
        context = []
        for j in range(max(0, i - R), min(len(text_ids), i + R + 1)):
            if i != j:
                context.append(text_ids[j])
        if (
            len(context) < 2 * R
        ):  # Adding a padding token of id 0 when the context is less than 2R
            context += [0] * (2 * R - len(context))
        contexts.append(context)
    return words, contexts


words, contexts = extract_words_contexts([1, 2, 3, 4, 5], 2)
print(contexts)

[[2, 3, 0, 0], [1, 3, 4, 0], [1, 2, 4, 5], [2, 3, 5, 0], [3, 4, 0, 0]]


In [120]:
def flatten_dataset_to_lists(dataset, R):
    words = []
    contexts = []
    for example in dataset:
        w, c = extract_words_contexts(example["review_ids"], R)
        words.extend(w)
        contexts.extend(c)
    return contexts, words

In [121]:
# Define the context window size R
R = 10

train_contexts, train_words = flatten_dataset_to_lists(document_train_set, R)
val_contexts, val_words = flatten_dataset_to_lists(document_valid_set, R)

print(f"List of words: {train_words[:10]}\nList of contexts: {train_contexts[:10]}")

List of words: [2054, 2057, 2031, 2182, 2003, 1037, 2143, 3819, 2005, 3087]
List of contexts: [[2057, 2031, 2182, 2003, 1037, 2143, 3819, 2005, 3087, 2008, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2054, 2031, 2182, 2003, 1037, 2143, 3819, 2005, 3087, 2008, 17257, 0, 0, 0, 0, 0, 0, 0, 0, 0], [2054, 2057, 2182, 2003, 1037, 2143, 3819, 2005, 3087, 2008, 17257, 1999, 0, 0, 0, 0, 0, 0, 0, 0], [2054, 2057, 2031, 2003, 1037, 2143, 3819, 2005, 3087, 2008, 17257, 1999, 1996, 0, 0, 0, 0, 0, 0, 0], [2054, 2057, 2031, 2182, 1037, 2143, 3819, 2005, 3087, 2008, 17257, 1999, 1996, 2088, 0, 0, 0, 0, 0, 0], [2054, 2057, 2031, 2182, 2003, 2143, 3819, 2005, 3087, 2008, 17257, 1999, 1996, 2088, 1997, 0, 0, 0, 0, 0], [2054, 2057, 2031, 2182, 2003, 1037, 3819, 2005, 3087, 2008, 17257, 1999, 1996, 2088, 1997, 2695, 0, 0, 0, 0], [2054, 2057, 2031, 2182, 2003, 1037, 2143, 2005, 3087, 2008, 17257, 1999, 1996, 2088, 1997, 2695, 1011, 0, 0, 0], [2054, 2057, 2031, 2182, 2003, 1037, 2143, 3819, 3087, 2008, 17257, 1999, 1996

In [122]:
class DocumentDataset(Dataset):
    def __init__(self, words, contexts):
        self.words = words
        self.contexts = contexts

    def __len__(self):
        return len(self.words)

    def __getitem__(self, idx):
        return (torch.tensor(self.contexts[idx]), torch.tensor(self.words[idx]))


train_set = DocumentDataset(train_words, train_contexts)
val_set = DocumentDataset(val_words, val_contexts)

In [123]:
def collate_fn(batch, K, vocab):
    positive_context_ids, word_ids = zip(*batch)

    # Convert to torch tensors
    word_ids = torch.tensor(word_ids)
    positive_context_ids = torch.stack(
        [torch.tensor(ids) for ids in positive_context_ids]
    )

    # Generate negative samples
    batch_size, context_size = positive_context_ids.shape
    vocab_size = len(vocab)
    negative_context_ids = torch.randint(
        0, vocab_size, (batch_size, K * context_size), dtype=torch.long
    )

    return {
        "word_ids": word_ids,
        "positive_context_ids": positive_context_ids,
        "negative_context_ids": negative_context_ids,
    }

In [124]:
# Some Parameters
K = 5
batch_size = 256

train_dataloader = DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    collate_fn=lambda batch: collate_fn(batch, K=K, vocab=tokenizer.vocab),
)
val_data_loader = DataLoader(
    dataset=val_set,
    batch_size=batch_size,
    collate_fn=lambda batch: collate_fn(batch, K=K, vocab=tokenizer.vocab),
)

In [125]:
# Display a few batches
iterations = 0
print(f"R={R}, K={K}, batch_size={batch_size}")

for batch in train_dataloader:
    if iterations > 2:
        break

    print(
        f"""- Batch {iterations}
    words_ids: {batch['word_ids'].shape}, 
    positive_context_ids: {batch['positive_context_ids'].shape}, 
    negative_context_ids: {batch['negative_context_ids'].shape}"""
    )
    
    iterations += 1

R=10, K=5, batch_size=256
- Batch 0
    words_ids: torch.Size([256]), 
    positive_context_ids: torch.Size([256, 20]), 
    negative_context_ids: torch.Size([256, 100])
- Batch 1
    words_ids: torch.Size([256]), 
    positive_context_ids: torch.Size([256, 20]), 
    negative_context_ids: torch.Size([256, 100])
- Batch 2
    words_ids: torch.Size([256]), 
    positive_context_ids: torch.Size([256, 20]), 
    negative_context_ids: torch.Size([256, 100])


  [torch.tensor(ids) for ids in positive_context_ids]


In [126]:
# Word2Vec Model
class Word2Vec(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(Word2Vec, self).__init__()
        # We create two embedding layers and we set the padding_idx to 0
        self.word_embed = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.context_embed = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

    def forward(self, word_ids, context_ids):
        w = self.word_embed(word_ids)
        C = self.context_embed(context_ids)
        dot_product = torch.bmm(w.unsqueeze(1), C.transpose(1, 2)).squeeze(1)
        return dot_product

In [127]:
"""
We redo some of the previous steps to properly parametrized our training function,
such as the dataloaders
"""


def training(
    model,
    document_train_set,
    document_valid_set,
    learning_rate,
    device,
    tokenizer,
    B,
    E,
    K,
    R,
):

    # Model to device
    model.to(device)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Loading the data
    print("Loading the data...")
    train_contexts, train_words = flatten_dataset_to_lists(document_train_set, R)
    val_contexts, val_words = flatten_dataset_to_lists(document_valid_set, R)
    train_set = DocumentDataset(train_words, train_contexts)
    val_set = DocumentDataset(val_words, val_contexts)
    train_dataloader = DataLoader(
        dataset=train_set,
        batch_size=B,
        collate_fn=lambda batch: collate_fn(batch, K=K, vocab=tokenizer.vocab),
    )
    val_data_loader = DataLoader(
        dataset=val_set,
        batch_size=B,
        collate_fn=lambda batch: collate_fn(batch, K=K, vocab=tokenizer.vocab),
    )

    print(f"Number of Parameters: {sum(p.numel() for p in model.parameters())}")

    # Training the model
    model.train()
    for epoch in range(E):
        total_loss = 0
        for batch in tqdm(train_dataloader, desc="Training"):
            word_ids = batch["word_ids"].to(device)
            positive_context_ids = batch["positive_context_ids"].to(device)
            negative_context_ids = batch["negative_context_ids"].to(device)

            # Outputs of the model
            logits_pos = model(word_ids, positive_context_ids)
            logits_neg = model(word_ids, negative_context_ids)

            # Loss
            loss_pos = F.logsigmoid(logits_pos).mean()
            loss_neg = F.logsigmoid(-logits_neg).mean()
            loss = -loss_pos - loss_neg

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Testing the model on the validation set
        model.eval()
        correct_predictions = 0
        total_predictions = 0
        with torch.no_grad():
            for batch in val_data_loader:
                word_ids = batch["word_ids"].to(device)
                positive_context_ids = batch["positive_context_ids"].to(device)
                negative_context_ids = batch["negative_context_ids"].to(device)

                # Outputs of the model
                logits_pos = model(word_ids, positive_context_ids)
                logits_neg = model(word_ids, negative_context_ids)

                # Accuracy, we don't count the padding tokens
                correct_predictions += (
                    ((F.sigmoid(logits_pos) > 0.5) & (positive_context_ids != 0))
                    .sum()
                    .item()
                )
                correct_predictions += (
                    ((F.sigmoid(logits_neg) < 0.5) & (negative_context_ids != 0))
                    .sum()
                    .item()
                )
                total_pos = (positive_context_ids != 0).sum().item()
                total_neg = (negative_context_ids != 0).sum().item()
                total_predictions += total_pos + total_neg

            accuracy = correct_predictions / total_predictions

        print(
            f"Epoch {epoch + 1} | Training Loss: {total_loss/len(train_dataloader)}, Test Accuracy: {accuracy:.4f}"
        )

    print("Training completed!")

In [128]:
def save_model(model, model_path):
    torch.save({
        'word_embed': model.word_embed.weight.data.clone(),  # Save word embeddings
        'context_embed': model.context_embed.weight.data.clone()  # Save context embeddings
    }, model_path)

In [129]:
# Hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
learning_rate = 0.001
embedding_dim = 100
B = 256  # Batch size
E = 5  # Number of epochs
K = 2  # Factor for negative samples
R = 10  # Context window size

# Model
model = Word2Vec(vocab_size=tokenizer.vocab_size, embedding_dim=embedding_dim)

# We can now train the model and modify the hyperparameters as we want


training(
    model,
    document_train_set,
    document_valid_set,
    learning_rate=learning_rate,
    device=device,
    tokenizer=tokenizer,
    B=B,
    E=E,
    K=K,
    R=R,
)

# Save the model
file_name = f"model_dim-{embedding_dim}_radius-{R}_ratio-{K}-batch-{B}-epoch-{E}.ckpt"
save_model(model, file_name)
print(f"Model saved as {file_name}")

Loading the data...
Number of Parameters: 6104400


  [torch.tensor(ids) for ids in positive_context_ids]
Training: 100%|██████████| 3631/3631 [02:17<00:00, 26.37it/s]


Epoch 1 | Training Loss: 3.915961238379913, Test Accuracy: 0.7876


Training: 100%|██████████| 3631/3631 [02:15<00:00, 26.73it/s]


Epoch 2 | Training Loss: 1.4671830851753958, Test Accuracy: 0.8541


Training: 100%|██████████| 3631/3631 [02:15<00:00, 26.82it/s]


Epoch 3 | Training Loss: 1.0018951136262841, Test Accuracy: 0.8739


Training: 100%|██████████| 3631/3631 [02:14<00:00, 27.08it/s]


Epoch 4 | Training Loss: 0.8120698052409988, Test Accuracy: 0.8830


Training: 100%|██████████| 3631/3631 [02:14<00:00, 27.00it/s]


Epoch 5 | Training Loss: 0.7084791955869738, Test Accuracy: 0.8877
Training completed!
Model saved as model_dim-100_radius-10_ratio-2-batch-256-epoch-5.ckpt
