In [18]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [19]:
seed = 42

In [20]:
from transformers import LlamaModel, AutoTokenizer
from datasets import load_dataset


model = LlamaModel.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T"
)
tokenizer = AutoTokenizer.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T"
)

base_dataset = load_dataset(
    "andersonbcdefg/synthetic_tuples_gpt35_turbo", split="train"
)

In [21]:
from peft import LoraConfig, TaskType
from transformers import TrainingArguments, IntervalStrategy


# consider and experiment withadding a specific pad token
# tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.unk_token_id

# Tokenize the dataset
def tokenize_function(examples):
    max_len = 128 # TODO - reconsider this
    tokenized_query = tokenizer(examples["query"], padding=True, truncation=True, max_length=max_len, return_tensors="pt")
    tokenized_pos = tokenizer(examples["pos"], padding=True,  truncation=True, max_length=max_len, return_tensors="pt")
    tokenized_neg = tokenizer(examples["neg"], padding=True, truncation=True, max_length=max_len, return_tensors="pt")
    return {
        "input_ids_query": tokenized_query["input_ids"],
        "attention_mask_query": tokenized_query["attention_mask"],
        "input_ids_pos": tokenized_pos["input_ids"],
        "attention_mask_pos": tokenized_pos["attention_mask"],
        "input_ids_neg": tokenized_neg["input_ids"],
        "attention_mask_neg": tokenized_neg["attention_mask"],
    }


dataset = base_dataset.map(tokenize_function, batched=True, cache_file_name="./cache/tokenized_datasets")

train_test_split = dataset.train_test_split(test_size=0.2, seed=seed)

# Access the new train and test datasets
train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]

In [22]:
peft_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    r=32,
    lora_alpha=64,
    lora_dropout=0.1,
    target_modules=[
        "q_proj",
        "v_proj",
        "o_proj",
        "down_proj",
        "up_proj",
        "gate_proj",
    ],
    inference_mode=False,
)

model.add_adapter(peft_config)

In [41]:
from transformers import Trainer
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from datasets import Dataset


class TinyEmbedTrainer(Trainer):
    def __init__(
        self,
        model,
        args: TrainingArguments,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        tokenizer: AutoTokenizer,
    ):
        # Consider reworking the model's signature to conform to training expectations
        args.remove_unused_columns = False
        super().__init__(
            model,
            args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            data_collator=self._map_collate_fn,
        )
        # self.train_dataset = train_dataset
        # self.eval_dataset = eval_dataset
        # self.tokenizer = tokenizer

    def _map_collate_fn(self, batch):
        # Implement your custom collate logic here
        # TODO - truncate to max length for each of query, pos, neg
        def get_field(field: str):
            attention = [item["attention_mask_" + field] for item in batch]
            inputs = [item["input_ids_" + field] for item in batch]
            # probably something wrong at the dataset level... but munge things into the right shape here,
            # truncate the batch to the max length, but also 0 pad those that are less than the max length
            # for both input_ids and attention_mask
            max_len = max([sum(x) for x in attention])
            attention_mask = []
            input_ids = []
            for i in range(len(attention)):
                attention_i = attention[i]
                input_ids_i = inputs[i]
                attention_trunc = attention_i[:max_len]
                input_ids_trunc = input_ids_i[:max_len]
                attention_pad = attention_trunc + [0] * (max_len - len(attention_trunc))
                input_ids_pad = input_ids_trunc + [self.tokenizer.pad_token_id] * (
                    max_len - len(input_ids_trunc)
                )
                attention_mask.append(attention_pad)
                input_ids.append(input_ids_pad)

            return {
                ("attention_mask_" + field): torch.tensor(attention_mask),
                ("input_ids_" + field): torch.tensor(input_ids),
            }
        return {
            **get_field("query"), **get_field("pos"), **get_field("neg")
        }

    def compute_loss(self, model, inputs, return_outputs=False):
        # print(inputs)
        # print(model)
        query_inputs = {"input_ids": inputs["input_ids_query"], "attention_mask": inputs["attention_mask_query"]}
        pos_inputs = {"input_ids": inputs["input_ids_pos"], "attention_mask": inputs["attention_mask_pos"]}
        neg_inputs = {"input_ids": inputs["input_ids_neg"], "attention_mask": inputs["attention_mask_neg"]}
        outputs_query = model(**query_inputs, output_hidden_states=True)
        outputs_pos = model(**pos_inputs, output_hidden_states=True)
        outputs_neg = model(**neg_inputs, output_hidden_states=True)
        raise NotImplementedError("I am but a poor boy, from a poor family. Scalamoose")
        outputs = model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]  # Get the last hidden states

        # Compute the indices of the last non-padding tokens
        attention_mask = inputs["attention_mask"]
        sequence_lengths = attention_mask.sum(dim=1)
        last_token_indices = sequence_lengths - 1

        # Retrieve the embeddings for each example in the batch
        query_embeddings = hidden_states[
            torch.arange(hidden_states.size(0)), last_token_indices, :
        ]

        # Assuming that pos and neg inputs are also passed in the same manner
        # Replace 'pos_input' and 'neg_input' with the actual input names
        pos_outputs = model(**pos_input, output_hidden_states=True)
        neg_outputs = model(**neg_input, output_hidden_states=True)

        pos_hidden_states = pos_outputs.hidden_states[-1]
        neg_hidden_states = neg_outputs.hidden_states[-1]

        pos_embeddings = pos_hidden_states[
            torch.arange(pos_hidden_states.size(0)), last_token_indices, :
        ]
        neg_embeddings = neg_hidden_states[
            torch.arange(neg_hidden_states.size(0)), last_token_indices, :
        ]

        # Normalize the embeddings
        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
        pos_embeddings = F.normalize(pos_embeddings, p=2, dim=1)
        neg_embeddings = F.normalize(neg_embeddings, p=2, dim=1)

        # Compute InfoNCE loss
        pos_similarity = torch.sum(query_embeddings * pos_embeddings, dim=1)
        neg_similarity = torch.sum(query_embeddings * neg_embeddings, dim=1)
        losses = -torch.log(
            torch.exp(pos_similarity)
            / (torch.exp(pos_similarity) + torch.exp(neg_similarity))
        )
        return (losses.mean(), outputs) if return_outputs else losses.mean()


# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=1,
    num_train_epochs=3,
    logging_dir="./logs",
    logging_strategy=IntervalStrategy.STEPS,
    logging_steps=50,
)

trainer = TinyEmbedTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
)
trainer.train()

<torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0x2a6e25e40>
<torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0x2a6e25e40>
BaseModelOutputWithPast(last_hidden_state=tensor([[[ -0.7818,   0.7505,  -0.8306,  ...,   1.9186,  -0.5190,   0.0484],
         [ -1.8923,  -0.1261,  -0.3167,  ...,  -0.1625,   1.6862,  -1.6313],
         [ -0.3130,   0.4132,   0.8534,  ..., -10.6115,   0.8348,  -0.7307],
         ...,
         [  0.5226,   0.1021,   0.9319,  ...,  -2.6739,   1.3098,   1.5813],
         [ -1.4857,   0.5001,   0.4461,  ...,  -2.4791,   1.1859,  -0.0396],
         [ -0.6583,   0.5537,   0.0921,  ...,  -2.7668,   1.4957,   1.1417]]],
       device='mps:0', grad_fn=<MulBackward0>), past_key_values=((tensor([[[[ 3.8590e-02, -1.1799e-01,  1.4089e-02,  ..., -1.2343e-02,
            2.3950e-02, -3.4950e-03],
          [ 4.1151e-01,  2.1022e-01,  1.3271e-01,  ..., -7.7221e-02,
           -1.2253e-02, -6.1907e-02],
          [ 6.3104e-01,  2.8880e-01,  

NotImplementedError: I am but a poor boy, from a poor family. Scalamoose