In [None]:
from neural_cherche import models, utils, train, losses
import torch
from pathlib import Path
import json
from time import time
from datasets import Dataset, load_dataset

# Identifier of the pretrained model; the model and the tokenizer are both loaded from it.
model_name = "raphaelsty/neural-cherche-sparse-embed"

# When the MPS backend cannot be used, report which of the two causes applies.
if not torch.backends.mps.is_available():
    print(
        (
            "MPS not available because the current MacOS version is not 12.3+ "
            "and/or you do not have an MPS-enabled device on this machine."
        )
        if torch.backends.mps.is_built()
        else (
            "MPS not available because the current PyTorch install was not "
            "built with MPS enabled."
        )
    )

In [None]:
# Run on the MPS backend when PyTorch reports it as available, otherwise on CPU.
compute_device = "mps" if torch.backends.mps.is_available() else "cpu"
model = models.SparseEmbed(model_name_or_path=model_name, device=compute_device)

# Scheduler queried during training for the FLOPS loss weight.
flops_scheduler = losses.FlopsScheduler()
# AdamW over all model parameters with a learning rate of 3e-6.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-6)

In [None]:
def triplet_to_tuple_input(triple: dict):
    """Turn one triplet record into a 3-tuple of its field values.

    The record names its own triplet fields: the first entry of
    ``triple['metadata']['objective']['triplet']`` is a sequence whose first
    three items are keys of ``triple``. The values stored under those keys
    are returned in that order.

    Raises:
        KeyError: if the metadata path or one of the named fields is missing.
        IndexError: if fewer than three field names are listed.
    """
    field_names = triple['metadata']['objective']['triplet'][0]
    return tuple(triple[field_names[position]] for position in range(3))

In [None]:
# Convenience shortcut: `some_path.ls()` returns the directory entries as a list.
# The monkey-patch is kept as-is in case other cells rely on it.
Path.ls = lambda x: list(x.iterdir())
nli_triplets = Path("../nomic-triplets/nli-triplets")

# Collect one 3-tuple per JSON line from every file whose name contains "shard".
triplets = []
# `iterdir()` yields entries in arbitrary order; sorting makes the resulting
# list identical from run to run.
for file in sorted(nli_triplets.ls()):
    if "shard" not in file.name:
        continue
    # Decode explicitly as UTF-8 instead of relying on the platform default.
    with file.open("r", encoding="utf-8") as f:
        # Stream the file line by line rather than materialising it with
        # readlines(), and skip blank lines (e.g. an empty last line), which
        # json.loads would reject.
        triplets.extend(
            triplet_to_tuple_input(json.loads(line)) for line in f if line.strip()
        )

In [None]:
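# ds = Dataset.from_list(triplets)
# NOTE: never paste an access token into the notebook; read it from the environment.
# The token that was previously hardcoded on this line has been exposed and should be revoked.
# ds.push_to_hub("NirantK/nli-triplets", token=os.environ["HF_TOKEN"])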

In [None]:
import json
from transformers import AutoTokenizer

# Tokenizer loaded from the same `model_name` as the model; its `decode` method
# is what this notebook uses to turn token ids back into text.
tokenizer = AutoTokenizer.from_pretrained(model_name)

def display_tokens_and_weights(sparse_embedding, tokenizer):
    """Print and return a token -> weight mapping, largest weight first.

    Each entry of ``sparse_embedding.indices`` is decoded on its own with
    ``tokenizer.decode`` and paired with the entry at the same position in
    ``sparse_embedding.values``. If two indices decode to the same token, the
    later weight replaces the earlier one. The mapping is printed as JSON
    (4-space indent) and returned as a dict ordered by descending weight.
    """
    weight_by_token = {
        tokenizer.decode([sparse_embedding.indices[pos]]): sparse_embedding.values[pos]
        for pos in range(len(sparse_embedding.indices))
    }

    # Stable in-place sort on the weight, heaviest first.
    ranked = list(weight_by_token.items())
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    ordered = dict(ranked)

    print(json.dumps(ordered, indent=4))
    return ordered

In [9]:
# Two epochs over the triplets in shuffled batches of 16.
triplet_batches = utils.iter(
    triplets,
    epochs=2,
    batch_size=16,
    shuffle=True,
)

for step, (anchor, positive, negative) in enumerate(triplet_batches):
    loss = train.train_sparse_embed(
        model=model,
        optimizer=optimizer,
        anchor=anchor,
        positive=positive,
        negative=negative,
        step=step,
        gradient_accumulation_steps=50,
        threshold_flops=30,
        flops_loss_weight=flops_scheduler.get(),
    )

    # Everything below runs only on every 50th step.
    if (step + 1) % 50 != 0:
        continue

    # Save the model every 50 steps, then reload the saved copy so the probe
    # below exercises what was actually written to disk.
    model.save_pretrained("checkpoint")
    checkpoint = models.SparseEmbed(
        model_name_or_path="checkpoint",
        device="mps" if torch.backends.mps.is_available() else "cpu",
    )

    # Encode one fixed sentence as a query, then as a document, timing each
    # call and printing the decoded activations.
    probe_text = "Hello World from Qdrant!"

    time_now = time()
    query_activations = checkpoint.encode([probe_text], query_mode=True)
    encode_time = time() - time_now
    print(f"Query Encode time: {encode_time}")
    print(tokenizer.decode(query_activations['activations'][0]))

    time_now = time()
    document_activations = checkpoint.encode([probe_text], query_mode=False)
    encode_time = time() - time_now
    print(f"Document Encode time: {encode_time}")
    print(tokenizer.decode(document_activations['activations'][0]))

Epoch 0:   0%|          | 62/17321 [03:56<18:15:10,  3.81s/it]


KeyboardInterrupt: 