In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os
import dotenv
from pathlib import Path

env_file = "../.env"

if os.path.exists(env_file):
    dotenv.load_dotenv(env_file, verbose=True)
    print("Loaded environment variables from .env file.")

cwd = os.getcwd()
# for some reason appending to PATH you need it to be string
sys.path.append(str(Path(cwd).parent / "src"))
hf_access_token = os.getenv("HUGGINGFACE_API_KEY")

Loaded environment variables from .env file.


In [3]:
from unlearn_order.dataset import load_dataset
data_dir = Path(cwd).parent / "data" / "random_bd"

files = ["split_0.jsonl"]
dataset = load_dataset(data_dir, files)

In [4]:
# i should prepare context vectors
from typing import Dict
from transformers import LlamaTokenizer
from functools import partial
from copy import deepcopy
import torch
from transformers import AutoTokenizer
from unlearn_order.utils import create_prompt, create_prompt_letter_answer, Point

def map_fn(data: Point, tokenizer: LlamaTokenizer):
    context = create_prompt(data)
    context_masks = []
    completion_masks = []
    input_ids = []

    for idx, choice in enumerate(data["choices"]):
        new_data = deepcopy(data)
        new_data["answer"] = idx
        completion = create_prompt_letter_answer(new_data)
        context_ids = tokenizer.encode(context, return_tensors="pt")
        input_id = tokenizer.encode(completion, return_tensors="pt")

        context_ids = context_ids[0]
        input_id = input_id[0]
        
        all_ids = torch.cat([context_ids, input_id])
        context_mask = torch.zeros_like(all_ids)
        context_mask[:context_ids.size(0)] = 1

        completion_mask = torch.zeros_like(all_ids)
        completion_mask[context_ids.size(0):] = 1

        context_masks.append(context_mask)
        completion_masks.append(completion_mask)
        input_ids.append(all_ids)

    return {
        "context_masks": context_masks,
        "completion_masks": completion_masks,
        "input_ids": input_ids,
        "answer": data["answer"]
    }

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = dataset.map(partial(map_fn, tokenizer=tokenizer))



Map:   0%|          | 0/157 [00:00<?, ? examples/s]

In [4]:
PAD_TOKEN_ID = tokenizer.eos_token_id
print(PAD_TOKEN_ID)

from typing import List
import torch.nn.functional as F
# byte length normalization


    

def batch_force(batch: Dict[str, torch.Tensor], max_length: int = 128, pad_token_dict: Dict[str, int] = {
    "completion_masks": 0,
    "context_masks": 0,
    "input_ids": PAD_TOKEN_ID,
}):
    # first dimension of each tensor is the batch size
    # batch[key] = [batch_size, seq_len]
    # truncate to max length
    keys = list(pad_token_dict.keys())
    answers = []
    n_choices = []
    new_batch = {}
    for samples in batch:
        n = 0
        for key in keys:
            if key not in new_batch:
                new_batch[key] = []
            for idx in range(len(samples[key])):
                new_batch[key].append(samples[key][idx][:max_length])
            n = len(samples[key])
        n_choices.append(n)
        answers.append(samples["answer"])
    

    batch = new_batch
    
    for key in keys:
        for idx in range(len(batch[key])):
            batch[key][idx] = batch[key][idx][:max_length]  
    
    # pad on the left
    attention_masks = []
    for key in keys:
        max_len = max(len(batch[key][idx]) for idx in range(len(batch[key])))
        
        for idx in range(len(batch[key])):
            padding_length = max_len - len(batch[key][idx])
            attention_mask = torch.tensor([0] * padding_length + [1] * len(batch[key][idx]))
            batch[key][idx] =  torch.cat([torch.tensor([
                pad_token_dict[key]
            ] * padding_length, dtype=torch.int64), torch.tensor(batch[key][idx])])

            # hack because i only want one attention mask for sample, not one per key
            if key == keys[0]:
                attention_masks.append(attention_mask)
    return {
        "context_masks": torch.stack(batch["context_masks"]),
        "completion_masks": torch.stack(batch["completion_masks"]),
        "attention_masks": torch.stack(attention_masks),
        "input_ids": torch.stack(batch["input_ids"]),
        "answers": torch.tensor(answers),
        "n_choices": torch.tensor(n_choices)
    }
    

    
def collate_fn(batch, max_length=128):
    # Get the token ids and find the maximum length in the current batc
    return batch_force(batch, max_length=max_length)


from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)

128009


In [5]:
for batch in dataloader:
    for k,v in batch.items():
        print(k, v.shape)
    print(batch["input_ids"][-5])
    break

context_masks torch.Size([16, 73])
completion_masks torch.Size([16, 73])
attention_masks torch.Size([16, 73])
input_ids torch.Size([16, 73])
answers torch.Size([4])
n_choices torch.Size([4])
tensor([128009, 128009, 128000,   4599,    574,  40139,  44707,   9405,   5380,
            32,     13,    220,   2550,     20,    198,     33,     13,    220,
          1049,     20,    198,     34,     13,    220,   2366,     17,    198,
            35,     13,    220,   4468,     22,    198,  16533,     25, 128000,
          4599,    574,  40139,  44707,   9405,   5380,     32,     13,    220,
          2550,     20,    198,     33,     13,    220,   1049,     20,    198,
            34,     13,    220,   2366,     17,    198,     35,     13,    220,
          4468,     22,    198,  16533,     25,    423,     13,    220,   4468,
            22])


In [6]:
from transformers import AutoModelForCausalLM
from research_tools.gpu import get_gpus_available

os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in get_gpus_available()])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [23]:

def compute_logits_loss(logits, labels, completion_masks):
    # logits = [batch_size, seq_len, vocab_size]
    # labels = [batch_size, seq_len]
    # completion_masks = [batch_size, seq_len]
    # mask out the padding tokens
    # print(labels.shape, logits.shape, completion_masks.shape)
    # shift the labels to the right
    labels = labels[..., 1:]
    logits = logits[..., :-1, :]
    batch_size = labels.size(0)
    seq_len = labels.size(1)

    logits = logits.reshape((batch_size * seq_len, -1))
    labels = labels.reshape((batch_size * seq_len))

    ce_loss = F.cross_entropy(logits, labels, reduction="none")
    ce_loss = ce_loss.view(batch_size, seq_len)
    ce_loss = ce_loss * completion_masks[..., 1:]
    normalization = completion_masks[..., 1:].sum(axis=1)
    ce_loss = ce_loss.sum(axis=1) 
    
    total_tokens = normalization.sum()

    loss = ce_loss.sum() / total_tokens
    return loss


def compute_byte_normalized_loss(input_ids, logits, labels, completion_masks, tokenizer):
    # logits = [batch_size, seq_len, vocab_size]
    # labels = [batch_size, seq_len]
    # completion_masks = [batch_size, seq_len]
    # mask out the padding tokens
    # print(labels.shape, logits.shape, completion_masks.shape)
    # shift the labels to the right
    batch_size = labels.size(0)
    byte_lengths = []
    
    for idx in range(batch_size):
        completion = tokenizer.decode(input_ids[idx][completion_masks[idx] == 1])
        length = len(completion.encode("utf-8"))
        byte_lengths.append(length)
    
    # get utf-8 byte lengths
    labels = labels[..., 1:]
    logits = logits[..., :-1, :]

    batch_size = labels.size(0)
    seq_len = labels.size(1)

    logits = logits.reshape((batch_size * seq_len, -1))
    labels = labels.reshape((batch_size * seq_len))

    ce_loss = F.cross_entropy(logits, labels, reduction="none")
    ce_loss = ce_loss.view(batch_size, seq_len)
    ce_loss = ce_loss * completion_masks[..., 1:]
    ce_loss = ce_loss.sum(axis=1) 
    ce_loss = ce_loss / torch.tensor(byte_lengths, dtype=torch.float32, device=device)
    return ce_loss

for batch in dataloader:
    input_ids = batch["input_ids"].to(device)
    attention_masks = batch["attention_masks"].to(device)
    completion_masks = batch["completion_masks"].to(device)
    answers = batch["answers"].to(device)
    n_choices = batch["n_choices"].to(device)

    # take the inds for the correct answer choice
    take_inds = n_choices.cumsum(axis=0)[:-1]
    take_inds = torch.cat([torch.tensor([0], device=device), take_inds])
    take_inds = take_inds + answers

    input_ids = input_ids[take_inds]
    attention_masks = attention_masks[take_inds]
    completion_masks = completion_masks[take_inds]

    labels = input_ids.clone()
    labels[completion_masks == 0] = -100
    outputs = model(input_ids, attention_mask=attention_masks, labels=labels)
    loss = outputs.loss
    print("HI")
    cross_entropy = compute_logits_loss(outputs.logits, labels, completion_masks)
    other_loss = compute_byte_normalized_loss(input_ids, outputs.logits, labels, completion_masks, tokenizer)
    print(loss, cross_entropy)
    print(other_loss)
    break


HI
[88, 91, 91, 90]
tensor(1.1299, device='cuda:0', grad_fn=<NllLossBackward0>) tensor(1.1299, device='cuda:0', grad_fn=<DivBackward0>)
tensor([0.4854, 0.4425, 0.5307, 0.4874], device='cuda:0',
       grad_fn=<DivBackward0>)


In [None]:
# corpus loss is just take union of context+completions
