In [1]:
from neural_cherche import models, utils, train, losses
import torch
from pathlib import Path
import json

model_name = "raphaelsty/neural-cherche-sparse-embed"

# Check that MPS is available
if not torch.backends.mps.is_available():
    if not torch.backends.mps.is_built():
        print("MPS not available because the current PyTorch install was not "
              "built with MPS enabled.")
    else:
        print("MPS not available because the current MacOS version is not 12.3+ "
              "and/or you do not have an MPS-enabled device on this machine.")

In [2]:
model = models.SparseEmbed(
    model_name_or_path=model_name,
    device="mps" if torch.backends.mps.is_available() else "cpu",
)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-6)
flops_scheduler = losses.FlopsScheduler()

In [3]:
def triplet_to_tuple_input(triple: dict):
    keys = triple['metadata']['objective']['triplet'][0]
    return (
        triple[keys[0]],
        triple[keys[1]],
        triple[keys[2]],
    )

In [4]:
Path.ls = lambda x: list(x.iterdir())
nli_triplets = Path("../nomic-triplets/nli-triplets")
nli_triplets.ls()

triplets = []
for file in nli_triplets.ls():
    if "shard" in file.name:
        with file.open("r") as f:
            file_content = f.readlines()
            triplets.extend([triplet_to_tuple_input(json.loads(line)) for line in file_content])

In [5]:
import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)

def display_tokens_and_weights(sparse_embedding, tokenizer):
    token_weight_dict = {}
    for i in range(len(sparse_embedding.indices)):
        token = tokenizer.decode([sparse_embedding.indices[i]])
        weight = sparse_embedding.values[i]
        token_weight_dict[token] = weight

    # Sort the dictionary by weights
    token_weight_dict = dict(sorted(token_weight_dict.items(), key=lambda item: item[1], reverse=True))
    print(json.dumps(token_weight_dict, indent=4))
    return token_weight_dict

tokenizer_config.json:   0%|          | 0.00/1.26k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/712k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

In [6]:
for step, (anchor, positive, negative) in enumerate(utils.iter(
        triplets,
        epochs=2,
        batch_size=16,
        shuffle=True
    )):

    loss = train.train_sparse_embed(
        model=model,
        optimizer=optimizer,
        anchor=anchor,
        positive=positive,
        negative=negative,
        threshold_flops=30,
        flops_loss_weight=flops_scheduler.get(),
        step=step,
        gradient_accumulation_steps=50,
    )

    if (step + 1) % 10 == 0:
        # Save the model every 1000 steps
        model.save_pretrained("checkpoint")
        checkpoint = models.SparseEmbed(
            model_name_or_path="checkpoint",
            device="mps" if torch.backends.mps.is_available() else "cpu",
        )
        query_activations = checkpoint.encode(["Hello World!"], query_mode=True)
        print(tokenizer.decode(query_activations['activations'][0]))
        document_activations = checkpoint.encode(["Hello World!"], query_mode=False)
        print(tokenizer.decode(document_activations['activations'][0]))

Epoch 0:   0%|          | 10/17321 [00:28<15:08:38,  3.15s/it]

hello world!world hi hawkins stanley roger spencer worlds mia simon nathan lucy happiness germany marcus reed company england arnold street fred nelson jay berlin global japan museum thomas winston jerry hotel ross israel united [PAD] [unused0] [unused1] [unused2] [unused3] [unused4] [unused5] [unused6] [unused7] [unused8] [unused9] [unused10] [unused11] [unused12] [unused13] [unused14] [unused15] [unused16] [unused17] [unused18] [unused19] [unused20] [unused21] [unused22] [unused23] [unused24] [unused25] [unused26]
hello world!world hi hawkins stanley roger spencer worlds mia simon nathan lucy happiness germany marcus reed company england arnold street fred nelson jay berlin global japan museum thomas winston jerry hotel ross israel united [PAD] [unused0] [unused1] [unused2] [unused3] [unused4] [unused5] [unused6] [unused7] [unused8] [unused9] [unused10] [unused11] [unused12] [unused13] [unused14] [unused15] [unused16] [unused17] [unused18] [unused19] [unused20] [unused21] [unused22] 

Epoch 0:   0%|          | 20/17321 [00:55<14:11:21,  2.95s/it]

hello world!world hi hawkins stanley roger spencer worlds mia simon nathan lucy happiness germany marcus reed company england arnold street fred nelson jay berlin global japan museum thomas winston jerry hotel ross israel united [PAD] [unused0] [unused1] [unused2] [unused3] [unused4] [unused5] [unused6] [unused7] [unused8] [unused9] [unused10] [unused11] [unused12] [unused13] [unused14] [unused15] [unused16] [unused17] [unused18] [unused19] [unused20] [unused21] [unused22] [unused23] [unused24] [unused25] [unused26]
hello world!world hi hawkins stanley roger spencer worlds mia simon nathan lucy happiness germany marcus reed company england arnold street fred nelson jay berlin global japan museum thomas winston jerry hotel ross israel united [PAD] [unused0] [unused1] [unused2] [unused3] [unused4] [unused5] [unused6] [unused7] [unused8] [unused9] [unused10] [unused11] [unused12] [unused13] [unused14] [unused15] [unused16] [unused17] [unused18] [unused19] [unused20] [unused21] [unused22] 

Epoch 0:   0%|          | 24/17321 [01:08<13:37:21,  2.84s/it]


KeyboardInterrupt: 